diff --git a/.github/workflows/basic-install.yml b/.github/workflows/basic-install.yml index a71bfc18e..9d71d0aab 100644 --- a/.github/workflows/basic-install.yml +++ b/.github/workflows/basic-install.yml @@ -20,8 +20,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - # disable windows build test as bilby_cython is currently broken there - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest, windows-latest] python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 @@ -52,8 +51,8 @@ jobs: python -c "import bilby.hyper" python -c "import cli_bilby" python test/import_test.py - # - if: ${{ matrix.os != "windows-latest" }} - # run: | - # for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do - # ${script} --help; - # done + - if: runner.os != 'Windows' + run: | + for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do + ${script} --help; + done diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 030526bc5..081e16456 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -62,6 +62,15 @@ jobs: - name: Run unit tests run: | python -m pytest --cov=bilby --cov-branch --durations 10 -ra --color yes --cov-report=xml --junitxml=pytest.xml + - name: Run jax-backend unit tests + run: | + python -m pip install .[jax] + SCIPY_ARRAY_API=1 pytest --array-backend jax --durations 10 + - name: Run torch-backend unit tests + # there are scipy version issues with python 3.10 and torch + if: matrix.python-version != '3.10' + run: | + SCIPY_ARRAY_API=1 pytest --array-backend torch --durations 10 - name: Run sampler tests run: | pytest test/integration/sampler_run_test.py --durations 10 -v diff --git a/bilby/compat/__init__.py b/bilby/compat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bilby/compat/jax.py b/bilby/compat/jax.py new file mode 100644 index 000000000..af0699147 --- /dev/null +++ 
b/bilby/compat/jax.py @@ -0,0 +1,40 @@ +import jax +import jax.numpy as jnp +from ..core.likelihood import Likelihood + + +class JittedLikelihood(Likelihood): + """ + A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`. + + Parameters + ========== + likelihood: bilby.core.likelihood.Likelihood + The likelihood to wrap. + cast_to_float: bool + Whether to return a float instead of a :code:`jax.Array`. + """ + + def __init__(self, likelihood, cast_to_float=True): + self._likelihood = likelihood + self._ll = jax.jit(likelihood.log_likelihood) + self._llr = jax.jit(likelihood.log_likelihood_ratio) + self.cast_to_float = cast_to_float + super().__init__() + + def __getattr__(self, name): + return getattr(self._likelihood, name) + + def log_likelihood(self, parameters): + parameters = {k: jnp.array(v) for k, v in parameters.items()} + ln_l = self._ll(parameters) + if self.cast_to_float: + ln_l = float(ln_l) + return ln_l + + def log_likelihood_ratio(self, parameters): + parameters = {k: jnp.array(v) for k, v in parameters.items()} + ln_l = self._llr(parameters) + if self.cast_to_float: + ln_l = float(ln_l) + return ln_l diff --git a/bilby/compat/patches.py b/bilby/compat/patches.py new file mode 100644 index 000000000..53345abd4 --- /dev/null +++ b/bilby/compat/patches.py @@ -0,0 +1,50 @@ +import array_api_compat as aac + +from .utils import BackendNotImplementedError + + +def erfinv_import(xp): + if aac.is_numpy_namespace(xp): + from scipy.special import erfinv + elif aac.is_jax_namespace(xp): + from jax.scipy.special import erfinv + elif aac.is_torch_namespace(xp): + from torch.special import erfinv + elif aac.is_cupy_namespace(xp): + from cupyx.scipy.special import erfinv + else: + raise BackendNotImplementedError + return erfinv + + +def multivariate_logpdf(xp, mean, cov): + if aac.is_numpy_namespace(xp): + from scipy.stats import multivariate_normal + + logpdf = multivariate_normal(mean=mean, cov=cov).logpdf + elif 
aac.is_jax_namespace(xp): + from functools import partial + from jax.scipy.stats.multivariate_normal import logpdf + + logpdf = partial(logpdf, mean=mean, cov=cov) + elif aac.is_torch_namespace(xp): + from torch.distributions.multivariate_normal import MultivariateNormal + + mvn = MultivariateNormal(loc=mean, covariance_matrix=xp.asarray(cov)) + logpdf = mvn.log_prob + else: + raise BackendNotImplementedError + return logpdf + + +def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, *, xp=None): + if xp is None: + xp = aac.get_namespace(a) + + # the scipy version of logsumexp cannot be vmapped + if aac.is_jax_namespace(xp): + from jax.scipy.special import logsumexp as lse + else: + from scipy.special import logsumexp as lse + + return lse(a=a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign) diff --git a/bilby/compat/types.py b/bilby/compat/types.py new file mode 100644 index 000000000..48c74c29f --- /dev/null +++ b/bilby/compat/types.py @@ -0,0 +1,4 @@ +import numpy as np + +Real = float | int +ArrayLike = np.ndarray | list | tuple diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py new file mode 100644 index 000000000..a05cc0920 --- /dev/null +++ b/bilby/compat/utils.py @@ -0,0 +1,105 @@ +import inspect +from collections.abc import Iterable + +import numpy as np +from array_api_compat import array_namespace + +from ..core.utils.log import logger + +__all__ = ["array_module", "promote_to_array"] + + +def array_module(arr): + if isinstance(arr, tuple) and len(arr) == 1: + arr = arr[0] + try: + return array_namespace(arr) + except TypeError: + if isinstance(arr, dict): + try: + return array_namespace(*[val for val in arr.values() if not isinstance(val, str)]) + except TypeError: + return np + elif arr.__class__.__module__ == "builtins" and isinstance(arr, Iterable): + try: + return array_namespace(*arr) + except TypeError: + return np + elif arr.__class__.__module__ == "builtins": + return np + elif 
arr.__module__.startswith("pandas"): + return np + else: + logger.warning( + f"Unknown array module for type: {type(arr)} Defaulting to numpy." + ) + return np + + +def promote_to_array(args, backend, skip=None): + if skip is None: + skip = len(args) + else: + skip = len(args) - skip + if backend.__name__ != "numpy": + args = tuple(backend.array(arg) for arg in args[:skip]) + args[skip:] + return args + + +def xp_wrap(func, no_xp=False): + """ + A decorator that will figure out the array module from the input + arguments and pass it to the function as the 'xp' keyword argument. + + Parameters + ========== + func: function + The function to be decorated. + no_xp: bool + If True, the decorator will not attempt to add the 'xp' keyword + argument and so the wrapper is a no-op. + + Returns + ======= + function + The decorated function. + """ + def parse_args_kwargs_for_xp(*args, xp=None, **kwargs): + if not no_xp and xp is None: + try: + # if the user specified the target arrays in kwargs + # we need to be able to support this, if there is + # only one kwargs, pass it through alone, this is + # sometimes a dictionary of arrays so this is needed + # to remove a level of nesting + if len(args) > 0: + xp = array_module(args) + elif len(kwargs) == 1: + xp = array_module(next(iter(kwargs.values()))) + elif len(kwargs) > 1: + xp = array_module(kwargs) + else: + xp = np + kwargs["xp"] = xp + except TypeError as e: + print("type failed", e) + kwargs["xp"] = np + elif not no_xp: + kwargs["xp"] = xp + return args, kwargs + + sig = inspect.signature(func) + if any(name in sig.parameters for name in ("self", "cls")): + def wrapped(self, *args, xp=None, **kwargs): + args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs) + return func(self, *args, **kwargs) + else: + def wrapped(*args, xp=None, **kwargs): + args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs) + return func(*args, **kwargs) + + return wrapped + + +class 
BackendNotImplementedError(NotImplementedError): + pass diff --git a/bilby/core/grid.py b/bilby/core/grid.py index 0d103d4cc..12dfa2bf6 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -1,6 +1,8 @@ import json import os +from copy import copy +import array_api_compat as aac import numpy as np from .likelihood import _safe_likelihood_call @@ -10,6 +12,7 @@ BilbyJsonEncoder, load_json, move_old_file ) from .result import FileMovedError +from ..compat.utils import array_module def grid_file_name(outdir, label, gzip=False): @@ -36,8 +39,11 @@ def grid_file_name(outdir, label, gzip=False): class Grid(object): - def __init__(self, likelihood=None, priors=None, grid_size=101, - save=False, label='no_label', outdir='.', gzip=False): + def __init__( + self, likelihood=None, priors=None, grid_size=101, + save=False, label='no_label', outdir='.', gzip=False, + xp=None, + ): """ Parameters @@ -58,8 +64,16 @@ def __init__(self, likelihood=None, priors=None, grid_size=101, The output directory to which the grid will be saved gzip: bool Set whether to gzip the output grid file + xp: array module | None + The array module to use for calculations (e.g., :code:`numpy`, + :code:`cupy`). If :code:`None`, defaults to :code:`numpy`. 
+ """ + if xp is None: + xp = np + logger.debug("No array module given for grid, defaulting to numpy.") + if priors is None: priors = dict() self.likelihood = likelihood @@ -68,13 +82,15 @@ def __init__(self, likelihood=None, priors=None, grid_size=101, self.parameter_names = list(self.priors.keys()) self.sample_points = dict() - self._get_sample_points(grid_size) + self._get_sample_points(grid_size, xp=xp) # evaluate the prior on the grid points if self.n_dims > 0: self._ln_prior = self.priors.ln_prob( {key: self.mesh_grid[i].flatten() for i, key in enumerate(self.parameter_names)}, axis=0).reshape( self.mesh_grid[0].shape) + else: + self._ln_prior = xp.asarray(0.0) self._ln_likelihood = None # evaluate the likelihood on the grid points @@ -97,12 +113,14 @@ def ln_prior(self): @property def prior(self): - return np.exp(self.ln_prior) + lnp = self.ln_prior + xp = array_module(lnp) + return xp.exp(lnp) @property def ln_likelihood(self): if self._ln_likelihood is None: - self._evaluate() + self._evaluate() return self._ln_likelihood @property @@ -116,7 +134,8 @@ def marginalize(self, log_array, parameters=None, not_parameters=None): Parameters ========== log_array: array_like - A :class:`numpy.ndarray` of log likelihood/posterior values. + A :code:`Python` array-api compatible array of log + likelihood/posterior values. parameters: list, str A list, or single string, of parameters to marginalize over. If None then all parameters will be marginalized over. @@ -151,7 +170,7 @@ def marginalize(self, log_array, parameters=None, not_parameters=None): else: raise TypeError("Parameters names must be a list or string") - out_array = log_array.copy() + out_array = copy(log_array) names = list(self.parameter_names) for name in params: @@ -166,7 +185,8 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): Parameters ========== log_array: array_like - A :class:`numpy.ndarray` of log likelihood/posterior values. 
+ A :code:`Python` array-api compatible array of log + likelihood/posterior values. name: str The name of the parameter to marginalize over. non_marg_names: list @@ -189,17 +209,26 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): non_marg_names.remove(name) places = self.sample_points[name] + xp = aac.get_namespace(log_array) if len(places) > 1: - dx = np.diff(places) - out = np.apply_along_axis( - logtrapzexp, axis, log_array, dx - ) + dx = xp.diff(places) + if log_array.ndim == 1: + out = logtrapzexp(log_array, dx=dx, xp=xp) + elif aac.is_torch_namespace(xp): + # https://discuss.pytorch.org/t/apply-a-function-along-an-axis/130440 + out = xp.stack([ + logtrapzexp(x_i, dx=dx, xp=xp) for x_i in xp.unbind(log_array, dim=axis) + ], dim=min(axis, log_array.ndim - 2)) + else: + out = xp.apply_along_axis( + logtrapzexp, axis, log_array, dx + ) else: # no marginalisation required, just remove the singleton dimension z = log_array.shape - q = np.arange(0, len(z)).astype(int) != axis - out = np.reshape(log_array, tuple((np.array(list(z)))[q])) + q = xp.arange(0, len(z)).astype(int) != axis + out = xp.reshape(log_array, tuple((xp.asarray(list(z)))[q])) return out @@ -277,8 +306,9 @@ def marginalize_likelihood(self, parameters=None, not_parameters=None): """ ln_like = self.marginalize(self.ln_likelihood, parameters=parameters, not_parameters=not_parameters) + xp = aac.get_namespace(ln_like) # NOTE: the output will not be properly normalised - return np.exp(ln_like - np.max(ln_like)) + return xp.exp(ln_like - xp.max(ln_like)) def marginalize_posterior(self, parameters=None, not_parameters=None): """ @@ -301,20 +331,33 @@ def marginalize_posterior(self, parameters=None, not_parameters=None): ln_post = self.marginalize(self.ln_posterior, parameters=parameters, not_parameters=not_parameters) # NOTE: the output will not be properly normalised - return np.exp(ln_post - np.max(ln_post)) + xp = aac.get_namespace(ln_post) + return xp.exp(ln_post - 
xp.max(ln_post)) def _evaluate(self): - self._ln_likelihood = np.empty(self.mesh_grid[0].shape) - self._evaluate_recursion(0, parameters=dict()) + xp = aac.get_namespace(self.mesh_grid[0]) + if aac.is_torch_namespace(xp) or aac.is_jax_namespace(xp): + if aac.is_torch_namespace(xp): + from torch import vmap + else: + from jax import vmap + self._ln_likelihood = vmap(self.likelihood.log_likelihood)( + {key: self.mesh_grid[i].flatten() for i, key in enumerate(self.parameter_names)} + ).reshape(self.mesh_grid[0].shape) + + else: + self._ln_likelihood = xp.empty(self.mesh_grid[0].shape) + self._evaluate_recursion(0, parameters=dict()) self.ln_noise_evidence = self.likelihood.noise_log_likelihood() def _evaluate_recursion(self, dimension, parameters): if dimension == self.n_dims: - current_point = tuple([[int(np.where( + xp = aac.get_namespace(self.mesh_grid[0]) + current_point = tuple([[xp.where( parameters[name] == - self.sample_points[name])[0])] for name in self.parameter_names]) - self._ln_likelihood[current_point] = _safe_likelihood_call( - self.likelihood, parameters + self.sample_points[name])[0].item()] for name in self.parameter_names]) + self._ln_likelihood[current_point] = ( + _safe_likelihood_call(self.likelihood, parameters) ) else: name = self.parameter_names[dimension] @@ -322,29 +365,29 @@ def _evaluate_recursion(self, dimension, parameters): parameters[name] = self.sample_points[name][ii] self._evaluate_recursion(dimension + 1, parameters) - def _get_sample_points(self, grid_size): + def _get_sample_points(self, grid_size, *, xp=np): for ii, key in enumerate(self.parameter_names): if isinstance(self.priors[key], Prior): if isinstance(grid_size, int): self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size)) + xp.linspace(0, 1, grid_size)) elif isinstance(grid_size, list): if isinstance(grid_size[ii], int): self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size[ii])) + xp.linspace(0, 1, grid_size[ii])) 
else: - self.sample_points[key] = grid_size[ii] + self.sample_points[key] = xp.asarray(grid_size[ii]) elif isinstance(grid_size, dict): if isinstance(grid_size[key], int): self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size[key])) + xp.linspace(0, 1, grid_size[key])) else: - self.sample_points[key] = grid_size[key] + self.sample_points[key] = xp.asarray(grid_size[key]) else: raise TypeError("Unrecognized 'grid_size' type") # set the mesh of points - self.mesh_grid = np.meshgrid( + self.mesh_grid = xp.meshgrid( *(self.sample_points[key] for key in self.parameter_names), indexing='ij') @@ -420,7 +463,7 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None, "following message:\n {} \n\n".format(e)) @classmethod - def read(cls, filename=None, outdir=None, label=None, gzip=False): + def read(cls, filename=None, outdir=None, label=None, gzip=False, xp=None): """ Read in a saved .json grid file Parameters @@ -433,6 +476,9 @@ def read(cls, filename=None, outdir=None, label=None, gzip=False): If given, whether the file is gzipped or not (only required if the file is gzipped, but does not have the standard '.gz' file extension) + xp: array module | None + The array module to use for calculations (e.g., :code:`numpy`, + :code:`jax.numpy`). If :code:`None`, defaults to :code:`numpy`. 
Returns ======= @@ -456,7 +502,7 @@ def read(cls, filename=None, outdir=None, label=None, gzip=False): try: grid = cls(likelihood=None, priors=dictionary['priors'], grid_size=dictionary['sample_points'], - label=dictionary['label'], outdir=dictionary['outdir']) + label=dictionary['label'], outdir=dictionary['outdir'], xp=xp) # set the likelihood grid._ln_likelihood = dictionary['ln_likelihood'] diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index 4d75033b9..cc2f3a854 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -3,11 +3,14 @@ import os import warnings +import array_api_compat as aac import numpy as np +from array_api_compat import is_array_api_obj from scipy.special import gammaln, xlogy -from scipy.stats import multivariate_normal from .utils import infer_parameters_from_function, infer_args_from_function_except_n_args, logger +from ..compat.patches import multivariate_logpdf +from ..compat.utils import BackendNotImplementedError, array_module PARAMETERS_AS_STATE = os.environ.get("BILBY_ALLOW_PARAMETERS_AS_STATE", "WARN") for msg in [ @@ -193,7 +196,7 @@ class ZeroLikelihood(Likelihood): def __init__(self, likelihood): super(ZeroLikelihood, self).__init__() - self.parameters = likelihood.parameters + self.parameters = dict() self._parent = likelihood def log_likelihood(self, parameters=None): @@ -309,9 +312,10 @@ def __init__(self, x, y, func, sigma=None, **kwargs): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) + xp = array_module(self.x) sigma = parameters.get("sigma", self.sigma) - log_l = np.sum(- (self.residual(parameters) / sigma)**2 / 2 - - np.log(2 * np.pi * sigma**2) / 2) + log_l = xp.sum(- (self.residual(parameters) / sigma)**2 / 2 - + xp.log(xp.asarray(2 * np.pi * sigma**2)) / 2) return log_l def __repr__(self): @@ -370,17 +374,18 @@ def __init__(self, x, y, func, **kwargs): def log_likelihood(self, parameters=None): rate = self.func(self.x, 
**self.model_parameters(parameters=parameters), **self.kwargs) - if not isinstance(rate, np.ndarray): + if not is_array_api_obj(rate): raise ValueError( "Poisson rate function returns wrong value type! " "Is {} when it should be numpy.ndarray".format(type(rate))) - elif np.any(rate < 0.): + xp = aac.get_namespace(rate) + if xp.any(rate < 0.): raise ValueError(("Poisson rate function returns a negative", " value!")) - elif np.any(rate == 0.): + elif xp.any(rate == 0.): return -np.inf else: - return np.sum(-rate + self.y * np.log(rate) - gammaln(self.y + 1)) + return xp.sum(-rate + self.y * xp.log(rate) - gammaln(self.y + 1)) def __repr__(self): return Analytical1DLikelihood.__repr__(self) @@ -392,10 +397,12 @@ def y(self): @y.setter def y(self, y): - if not isinstance(y, np.ndarray): - y = np.array([y]) + if not is_array_api_obj(y): + y = np.atleast_1d(y) + xp = aac.get_namespace(y) # check array is a non-negative integer array - if y.dtype.kind not in 'ui' or np.any(y < 0): + # torch doesn't support checking dtype kind + if (not aac.is_torch_namespace(xp) and y.dtype.kind not in 'ui') or xp.any(y < 0): raise ValueError("Data must be non-negative integers") self.__y = y @@ -421,9 +428,10 @@ def __init__(self, x, y, func, **kwargs): def log_likelihood(self, parameters=None): mu = self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) - if np.any(mu < 0.): + xp = aac.get_namespace(mu) + if xp.any(mu < 0.): return -np.inf - return -np.sum(np.log(mu) + (self.y / mu)) + return -xp.sum(xp.log(mu) + (self.y / mu)) def __repr__(self): return Analytical1DLikelihood.__repr__(self) @@ -435,9 +443,10 @@ def y(self): @y.setter def y(self, y): - if not isinstance(y, np.ndarray): - y = np.array([y]) - if np.any(y < 0): + if not is_array_api_obj(y): + y = np.atleast_1d(y) + xp = aac.get_namespace(y) + if xp.any(y < 0): raise ValueError("Data must be non-negative") self._y = y @@ -484,9 +493,10 @@ def log_likelihood(self, parameters=None): raise 
ValueError("Number of degrees of freedom for Student's " "t-likelihood must be positive") + xp = array_module(self.x) log_l =\ - np.sum(- (nu + 1) * np.log1p(self.lam * self.residual(parameters=parameters)**2 / nu) / 2 + - np.log(self.lam / (nu * np.pi)) / 2 + + xp.sum(- (nu + 1) * xp.log1p(self.lam * self.residual(parameters=parameters)**2 / nu) / 2 + + xp.log(xp.asarray(self.lam / (nu * np.pi))) / 2 + gammaln((nu + 1) / 2) - gammaln(nu / 2)) return log_l @@ -533,8 +543,10 @@ def __init__(self, data, n_dimensions, base="parameter_"): base: str The base of the parameter labels """ - self.data = np.array(data) - self._total = np.sum(self.data) + if not is_array_api_obj(data): + data = np.array(data) + self.data = data + self._total = self.data.sum() super(Multinomial, self).__init__() self.n = n_dimensions self.base = base @@ -561,7 +573,8 @@ def noise_log_likelihood(self): def _multinomial_ln_pdf(self, probs): """Lifted from scipy.stats.multinomial._logpdf""" - ln_prob = gammaln(self._total + 1) + np.sum( + xp = array_module(self.data) + ln_prob = gammaln(self._total + 1) + xp.sum( xlogy(self.data, probs) - gammaln(self.data + 1), axis=-1) return ln_prob @@ -580,10 +593,17 @@ class AnalyticalMultidimensionalCovariantGaussian(Likelihood): """ def __init__(self, mean, cov): - self.cov = np.atleast_2d(cov) - self.mean = np.atleast_1d(mean) - self.sigma = np.sqrt(np.diag(self.cov)) - self.pdf = multivariate_normal(mean=self.mean, cov=self.cov) + xp = array_module(cov) + self.cov = xp.atleast_2d(cov) + self.mean = xp.atleast_1d(mean) + self.sigma = xp.sqrt(xp.diag(self.cov)) + try: + self.logpdf = multivariate_logpdf(xp, mean=self.mean, cov=self.cov) + except BackendNotImplementedError: + raise NotImplementedError( + f"Multivariate normal likelihood not implemented for {xp.__name__} backend" + ) + super(AnalyticalMultidimensionalCovariantGaussian, self).__init__() @property @@ -592,8 +612,9 @@ def dim(self): def log_likelihood(self, parameters=None): parameters = 
_fallback_to_parameters(self, parameters) - x = np.array([parameters["x{0}".format(i)] for i in range(self.dim)]) - return self.pdf.logpdf(x) + xp = array_module(self.cov) + x = xp.asarray([parameters["x{0}".format(i)] for i in range(self.dim)]) + return self.logpdf(x) class AnalyticalMultidimensionalBimodalCovariantGaussian(Likelihood): @@ -611,12 +632,18 @@ class AnalyticalMultidimensionalBimodalCovariantGaussian(Likelihood): """ def __init__(self, mean_1, mean_2, cov): - self.cov = np.atleast_2d(cov) - self.sigma = np.sqrt(np.diag(self.cov)) - self.mean_1 = np.atleast_1d(mean_1) - self.mean_2 = np.atleast_1d(mean_2) - self.pdf_1 = multivariate_normal(mean=self.mean_1, cov=self.cov) - self.pdf_2 = multivariate_normal(mean=self.mean_2, cov=self.cov) + xp = array_module(cov) + self.cov = xp.atleast_2d(cov) + self.sigma = xp.sqrt(xp.diag(self.cov)) + self.mean_1 = xp.atleast_1d(mean_1) + self.mean_2 = xp.atleast_1d(mean_2) + try: + self.logpdf_1 = multivariate_logpdf(xp, mean=self.mean_1, cov=self.cov) + self.logpdf_2 = multivariate_logpdf(xp, mean=self.mean_2, cov=self.cov) + except BackendNotImplementedError: + raise NotImplementedError( + f"Multivariate normal likelihood not implemented for {xp.__name__} backend" + ) super(AnalyticalMultidimensionalBimodalCovariantGaussian, self).__init__() @property @@ -625,8 +652,9 @@ def dim(self): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) - x = np.array([parameters["x{0}".format(i)] for i in range(self.dim)]) - return -np.log(2) + np.logaddexp(self.pdf_1.logpdf(x), self.pdf_2.logpdf(x)) + xp = array_module(self.cov) + x = xp.asarray([parameters["x{0}".format(i)] for i in range(self.dim)]) + return -xp.log(2) + xp.logaddexp(self.logpdf_1(x), self.logpdf_2(x)) class JointLikelihood(Likelihood): diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index bc47cf680..bec049d90 100644 --- a/bilby/core/prior/analytical.py +++ 
b/bilby/core/prior/analytical.py @@ -1,21 +1,25 @@ +import os + import numpy as np +os.environ["SCIPY_ARRAY_API"] = "1" # noqa # flag for scipy backend switching from scipy.special import ( - xlogy, - erf, - erfinv, - log1p, - stdtrit, - gammaln, - stdtr, - betaln, betainc, betaincinv, + betaln, + erf, gammaincinv, gammainc, + gammaln, + stdtr, + stdtrit, + xlogy, + xlog1py, ) from .base import Prior from ..utils import logger +from ...compat.patches import erfinv_import +from ...compat.utils import BackendNotImplementedError, array_module, xp_wrap class DeltaFunction(Prior): @@ -41,7 +45,7 @@ def __init__(self, peak, name=None, latex_label=None, unit=None): self._is_fixed = True self.least_recently_sampled = peak - def rescale(self, val): + def rescale(self, val, *, xp=None): """Rescale everything to the peak with the correct shape. Parameters @@ -54,7 +58,7 @@ def rescale(self, val): """ return self.peak * val ** 0 - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -67,10 +71,10 @@ def prob(self, val): """ at_peak = (val == self.peak) - return np.nan_to_num(np.multiply(at_peak, np.inf)) + return at_peak * 1.0 - def cdf(self, val): - return np.ones_like(val) * (val > self.peak) + def cdf(self, val, *, xp=None): + return 1.0 * (val > self.peak) class PowerLaw(Prior): @@ -101,7 +105,8 @@ def __init__(self, alpha, minimum, maximum, name=None, latex_label=None, boundary=boundary) self.alpha = alpha - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the power-law prior. 
@@ -117,12 +122,13 @@ def rescale(self, val): Union[float, array_like]: Rescaled probability """ if self.alpha == -1: - return self.minimum * np.exp(val * np.log(self.maximum / self.minimum)) + return self.minimum * xp.exp(val * xp.log(xp.asarray(self.maximum / self.minimum))) else: return (self.minimum ** (1 + self.alpha) + val * (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) ** (1. / (1 + self.alpha)) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -134,13 +140,16 @@ def prob(self, val): float: Prior probability of val """ if self.alpha == -1: - return np.nan_to_num(1 / val / np.log(self.maximum / self.minimum)) * self.is_in_prior_range(val) + return xp.nan_to_num( + 1 / val / xp.log(xp.asarray(self.maximum / self.minimum)) + ) * self.is_in_prior_range(val) else: - return np.nan_to_num(val ** self.alpha * (1 + self.alpha) / + return xp.nan_to_num(val ** self.alpha * (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) * self.is_in_prior_range(val) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the logarithmic prior probability of val Parameters @@ -153,28 +162,29 @@ def ln_prob(self, val): """ if self.alpha == -1: - normalising = 1. / np.log(self.maximum / self.minimum) + normalising = 1. / xp.log(xp.asarray(self.maximum / self.minimum)) else: - normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - - self.minimum ** (1 + self.alpha)) + normalising = (1 + self.alpha) / xp.asarray( + self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha) + ) with np.errstate(divide='ignore', invalid='ignore'): - ln_in_range = np.log(1. * self.is_in_prior_range(val)) - ln_p = self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising) + ln_in_range = xp.log(1. 
* self.is_in_prior_range(val)) + ln_p = self.alpha * xp.nan_to_num(xp.log(val)) + xp.log(normalising) return ln_p + ln_in_range - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): if self.alpha == -1: - _cdf = (np.log(val / self.minimum) / - np.log(self.maximum / self.minimum)) + with np.errstate(invalid="ignore"): + _cdf = xp.log(val / self.minimum) / xp.log(xp.asarray(self.maximum / self.minimum)) else: _cdf = ( (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) / (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) ) - _cdf = np.minimum(_cdf, 1) - _cdf = np.maximum(_cdf, 0) + _cdf = xp.clip(_cdf, 0, 1) return _cdf @@ -203,7 +213,7 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary) - def rescale(self, val): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the power-law prior. @@ -220,7 +230,7 @@ def rescale(self, val): """ return self.minimum + val * (self.maximum - self.minimum) - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -233,7 +243,8 @@ def prob(self, val): """ return ((val >= self.minimum) & (val <= self.maximum)) / (self.maximum - self.minimum) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the log prior probability of val Parameters @@ -244,13 +255,13 @@ def ln_prob(self, val): ======= float: log probability of val """ - return xlogy(1, (val >= self.minimum) & (val <= self.maximum)) - xlogy(1, self.maximum - self.minimum) + with np.errstate(divide="ignore"): + return xp.log(self.prob(val)) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): _cdf = (val - self.minimum) / (self.maximum - self.minimum) - _cdf = np.minimum(_cdf, 1) - _cdf = np.maximum(_cdf, 0) - return _cdf + return xp.clip(_cdf, 0, 1) class LogUniform(PowerLaw): @@ -310,7 +321,8 @@ def __init__(self, minimum, 
maximum, name=None, latex_label=None, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the power-law prior. @@ -325,21 +337,14 @@ def rescale(self, val): ======= Union[float, array_like]: Rescaled probability """ - if isinstance(val, (float, int)): - if val < 0.5: - return -self.maximum * np.exp(-2 * val * np.log(self.maximum / self.minimum)) - else: - return self.minimum * np.exp(np.log(self.maximum / self.minimum) * (2 * val - 1)) - else: - vals_less_than_5 = val < 0.5 - rescaled = np.empty_like(val) - rescaled[vals_less_than_5] = -self.maximum * np.exp(-2 * val[vals_less_than_5] * - np.log(self.maximum / self.minimum)) - rescaled[~vals_less_than_5] = self.minimum * np.exp(np.log(self.maximum / self.minimum) * - (2 * val[~vals_less_than_5] - 1)) - return rescaled - - def prob(self, val): + return ( + xp.sign(2 * val - 1) + * self.minimum + * xp.exp(xp.abs(2 * val - 1) * xp.log(xp.asarray(self.maximum / self.minimum))) + ) + + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -350,11 +355,12 @@ def prob(self, val): ======= float: Prior probability of val """ - val = np.abs(val) - return (np.nan_to_num(0.5 / val / np.log(self.maximum / self.minimum)) * + val = xp.abs(val) + return (xp.nan_to_num(0.5 / val / xp.log(xp.asarray(self.maximum / self.minimum))) * self.is_in_prior_range(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the logarithmic prior probability of val Parameters @@ -366,19 +372,12 @@ def ln_prob(self, val): float: """ - return np.nan_to_num(- np.log(2 * np.abs(val)) - np.log(np.log(self.maximum / self.minimum))) + return xp.nan_to_num(- xp.log(2 * xp.abs(val)) - xp.log(xp.log(xp.asarray(self.maximum / self.minimum)))) - def cdf(self, val): - norm = 0.5 / np.log(self.maximum / self.minimum) - _cdf = ( - -norm * 
np.log(abs(val) / self.maximum) - * (val <= -self.minimum) * (val >= -self.maximum) - + (0.5 + norm * np.log(abs(val) / self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 0.5 * (val > -self.minimum) * (val < self.minimum) - + 1 * (val > self.maximum) - ) - return _cdf + @xp_wrap + def cdf(self, val, *, xp=None): + asymmetric = xp.log(xp.abs(val) / self.minimum) / xp.log(xp.asarray(self.maximum / self.minimum)) + return xp.clip(0.5 * (1 + xp.sign(val) * asymmetric), 0, 1) class Cosine(Prior): @@ -405,16 +404,18 @@ def __init__(self, minimum=-np.pi / 2, maximum=np.pi / 2, name=None, super(Cosine, self).__init__(minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to a uniform in cosine prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - norm = 1 / (np.sin(self.maximum) - np.sin(self.minimum)) - return np.arcsin(val / norm + np.sin(self.minimum)) + norm = 1 / (xp.sin(xp.asarray(self.maximum)) - xp.sin(xp.asarray(self.minimum))) + return xp.arcsin(val / norm + xp.sin(xp.asarray(self.minimum))) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Defined over [-pi/2, pi/2]. 
Parameters @@ -425,15 +426,17 @@ def prob(self, val): ======= float: Prior probability of val """ - return np.cos(val) / 2 * self.is_in_prior_range(val) + return xp.cos(val) / 2 * self.is_in_prior_range(val) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): _cdf = ( - (np.sin(val) - np.sin(self.minimum)) - / (np.sin(self.maximum) - np.sin(self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) + (xp.sin(val) - xp.sin(xp.asarray(self.minimum))) / + (xp.sin(xp.asarray(self.maximum)) - xp.sin(xp.asarray(self.minimum))) ) + _cdf *= val >= self.minimum + _cdf *= val <= self.maximum + _cdf += val > self.maximum return _cdf @@ -461,16 +464,18 @@ def __init__(self, minimum=0, maximum=np.pi, name=None, super(Sine, self).__init__(minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to a uniform in sine prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - norm = 1 / (np.cos(self.minimum) - np.cos(self.maximum)) - return np.arccos(np.cos(self.minimum) - val / norm) + norm = 1 / (xp.cos(xp.asarray(self.minimum)) - xp.cos(xp.asarray(self.maximum))) + return xp.arccos(xp.cos(xp.asarray(self.minimum)) - val / norm) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Defined over [0, pi]. 
Parameters @@ -481,15 +486,17 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.sin(val) / 2 * self.is_in_prior_range(val) + return xp.sin(val) / 2 * self.is_in_prior_range(val) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): _cdf = ( - (np.cos(val) - np.cos(self.minimum)) - / (np.cos(self.maximum) - np.cos(self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) + (xp.cos(val) - xp.cos(xp.asarray(self.minimum))) + / (xp.cos(xp.asarray(self.maximum)) - xp.cos(xp.asarray(self.minimum))) ) + _cdf *= val >= self.minimum + _cdf *= val <= self.maximum + _cdf += val > self.maximum return _cdf @@ -517,7 +524,8 @@ def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=N self.mu = mu self.sigma = sigma - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Gaussian prior. @@ -527,9 +535,14 @@ def rescale(self, val): This maps to the inverse CDF. This has been analytically solved for this case. """ + try: + erfinv = erfinv_import(xp) + except BackendNotImplementedError: + raise NotImplementedError(f"Gaussian prior rescale not implemented for this {xp.__name__}") return self.mu + erfinv(2 * val - 1) * 2 ** 0.5 * self.sigma - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -540,9 +553,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 / self.sigma + return xp.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 / self.sigma - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the Log prior probability of val. 
Parameters @@ -553,10 +567,9 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ + return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + xp.log(xp.asarray(2 * np.pi * self.sigma ** 2))) - return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + np.log(2 * np.pi * self.sigma ** 2)) - - def cdf(self, val): + def cdf(self, val, *, xp=None): return (1 - erf((self.mu - val) / 2 ** 0.5 / self.sigma)) / 2 @@ -607,16 +620,24 @@ def normalisation(self): return (erf((self.maximum - self.mu) / 2 ** 0.5 / self.sigma) - erf( (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate truncated Gaussian prior. This maps to the inverse CDF. This has been analytically solved for this case. """ + try: + erfinv = erfinv_import(xp) + except BackendNotImplementedError: + raise NotImplementedError( + f"Truncated Gaussian prior rescale not implemented for this {xp.__name__}" + ) return erfinv(2 * val * self.normalisation + erf( (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) * 2 ** 0.5 * self.sigma + self.mu - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. 
Parameters @@ -627,17 +648,15 @@ def prob(self, val): ======= float: Prior probability of val """ - return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 \ + return xp.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 \ / self.sigma / self.normalisation * self.is_in_prior_range(val) - def cdf(self, val): - _cdf = ( - ( - erf((val - self.mu) / 2 ** 0.5 / self.sigma) - - erf((self.minimum - self.mu) / 2 ** 0.5 / self.sigma) - ) / 2 / self.normalisation * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) - ) + def cdf(self, val, *, xp=None): + _cdf = (erf((val - self.mu) / 2 ** 0.5 / self.sigma) - erf( + (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 / self.normalisation + _cdf *= val >= self.minimum + _cdf *= val <= self.maximum + _cdf += val > self.maximum return _cdf @@ -701,15 +720,21 @@ def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=N self.mu = mu self.sigma = sigma - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate LogNormal prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - return np.exp(self.mu + np.sqrt(2 * self.sigma ** 2) * erfinv(2 * val - 1)) + try: + erfinv = erfinv_import(xp) + except BackendNotImplementedError: + raise NotImplementedError(f"LogNormal prior rescale not implemented for this {xp.__name__}") + return xp.exp(self.mu + (2 * self.sigma ** 2)**0.5 * erfinv(2 * val - 1)) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Returns the prior probability of val. Parameters @@ -720,20 +745,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val <= self.minimum: - _prob = 0. 
- else: - _prob = np.exp(-(np.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2)\ - / np.sqrt(2 * np.pi) / val / self.sigma - else: - _prob = np.zeros(val.size) - idx = (val > self.minimum) - _prob[idx] = np.exp(-(np.log(val[idx]) - self.mu) ** 2 / self.sigma ** 2 / 2)\ - / np.sqrt(2 * np.pi) / val[idx] / self.sigma - return _prob + return xp.exp(self.ln_prob(val, xp=xp)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -744,30 +759,18 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val <= self.minimum: - _ln_prob = -np.inf - else: - _ln_prob = -(np.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2\ - - np.log(np.sqrt(2 * np.pi) * val * self.sigma) - else: - _ln_prob = -np.inf * np.ones(val.size) - idx = (val > self.minimum) - _ln_prob[idx] = -(np.log(val[idx]) - self.mu) ** 2\ - / self.sigma ** 2 / 2 - np.log(np.sqrt(2 * np.pi) * val[idx] * self.sigma) - return _ln_prob - - def cdf(self, val): - if isinstance(val, (float, int)): - if val <= self.minimum: - _cdf = 0. 
- else: - _cdf = 0.5 + erf((np.log(val) - self.mu) / self.sigma / np.sqrt(2)) / 2 - else: - _cdf = np.zeros(val.size) - _cdf[val > self.minimum] = 0.5 + erf(( - np.log(val[val > self.minimum]) - self.mu) / self.sigma / np.sqrt(2)) / 2 - return _cdf + with np.errstate(divide="ignore", invalid="ignore"): + return xp.nan_to_num(( + -(xp.log(xp.maximum(val, xp.asarray(self.minimum))) - self.mu) ** 2 / self.sigma ** 2 / 2 + - xp.log((2 * np.pi)**0.5 * val * self.sigma) + ) + xp.log(val > self.minimum), nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) + + @xp_wrap + def cdf(self, val, *, xp=None): + with np.errstate(divide="ignore"): + return 0.5 + erf( + (xp.log(xp.maximum(val, xp.asarray(self.minimum))) - self.mu) / self.sigma / np.sqrt(2) + ) / 2 class LogGaussian(LogNormal): @@ -795,15 +798,18 @@ def __init__(self, mu, name=None, latex_label=None, unit=None, boundary=None): unit=unit, boundary=boundary) self.mu = mu - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Exponential prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - return -self.mu * log1p(-val) + with np.errstate(divide="ignore", over="ignore"): + return -self.mu * xp.log1p(-val) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -814,17 +820,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val < self.minimum: - _prob = 0. - else: - _prob = np.exp(-val / self.mu) / self.mu - else: - _prob = np.zeros(val.size) - _prob[val >= self.minimum] = np.exp(-val[val >= self.minimum] / self.mu) / self.mu - return _prob + return xp.exp(self.ln_prob(val, xp=xp)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. 
Parameters @@ -835,26 +834,13 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val < self.minimum: - _ln_prob = -np.inf - else: - _ln_prob = -val / self.mu - np.log(self.mu) - else: - _ln_prob = -np.inf * np.ones(val.size) - _ln_prob[val >= self.minimum] = -val[val >= self.minimum] / self.mu - np.log(self.mu) - return _ln_prob - - def cdf(self, val): - if isinstance(val, (float, int)): - if val < self.minimum: - _cdf = 0. - else: - _cdf = 1. - np.exp(-val / self.mu) - else: - _cdf = np.zeros(val.size) - _cdf[val >= self.minimum] = 1. - np.exp(-val[val >= self.minimum] / self.mu) - return _cdf + with np.errstate(divide="ignore"): + return -val / self.mu - xp.log(xp.asarray(self.mu)) + xp.log(val >= self.minimum) + + @xp_wrap + def cdf(self, val, *, xp=None): + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + return xp.maximum(1. - xp.exp(-val / self.mu), xp.asarray(0.0)) class StudentT(Prior): @@ -891,26 +877,26 @@ def __init__(self, df, mu=0., scale=1., name=None, latex_label=None, self.mu = mu self.scale = scale - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Student's t-prior. This maps to the inverse CDF. This has been analytically solved for this case. + + Notes + ===== + This explicitly casts to the requested backend, but the computation will be done by scipy. 
""" - if isinstance(val, (float, int)): - if val == 0: - rescaled = -np.inf - elif val == 1: - rescaled = np.inf - else: - rescaled = stdtrit(self.df, val) * self.scale + self.mu - else: - rescaled = stdtrit(self.df, val) * self.scale + self.mu - rescaled[val == 0] = -np.inf - rescaled[val == 1] = np.inf - return rescaled + with np.errstate(divide="ignore", invalid="ignore"): + return ( + xp.nan_to_num(stdtrit(self.df, val) * self.scale + self.mu) + + xp.log(val > 0) + - xp.log(val < 1) + ) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -921,9 +907,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -934,11 +921,13 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return gammaln(0.5 * (self.df + 1)) - gammaln(0.5 * self.df)\ - - np.log(np.sqrt(np.pi * self.df) * self.scale) - (self.df + 1) / 2 *\ - np.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) + return ( + gammaln(0.5 * (self.df + 1)) - gammaln(0.5 * self.df) + - xp.log(xp.asarray((np.pi * self.df)**0.5 * self.scale)) - (self.df + 1) / 2 + * xp.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) + ) - def cdf(self, val): + def cdf(self, val, *, xp=None): return stdtr(self.df, (val - self.mu) / self.scale) @@ -980,15 +969,25 @@ def __init__(self, alpha, beta, minimum=0, maximum=1, name=None, self.alpha = alpha self.beta = beta - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Beta prior. This maps to the inverse CDF. This has been analytically solved for this case. 
+ + Notes + ===== + This explicitly casts to the requested backend, but the computation will be done by scipy. """ - return betaincinv(self.alpha, self.beta, val) * (self.maximum - self.minimum) + self.minimum + return ( + xp.asarray(betaincinv(xp.asarray(self.alpha), xp.asarray(self.beta), val)) + * (self.maximum - self.minimum) + + self.minimum + ) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -999,9 +998,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val, xp=xp)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -1012,37 +1012,19 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - _ln_prob = xlogy(self.alpha - 1, val - self.minimum) + xlogy(self.beta - 1, self.maximum - val)\ - - betaln(self.alpha, self.beta) - xlogy(self.alpha + self.beta - 1, self.maximum - self.minimum) - - # deal with the fact that if alpha or beta are < 1 you get infinities at 0 and 1 - if isinstance(val, (float, int)): - if np.isfinite(_ln_prob) and self.minimum <= val <= self.maximum: - return _ln_prob - return -np.inf - else: - _ln_prob_sub = np.full_like(val, -np.inf) - idx = np.isfinite(_ln_prob) & (val >= self.minimum) & (val <= self.maximum) - _ln_prob_sub[idx] = _ln_prob[idx] - return _ln_prob_sub - - def cdf(self, val): - if isinstance(val, (float, int)): - if val > self.maximum: - return 1. - elif val < self.minimum: - return 0. - else: - return betainc( - self.alpha, self.beta, - (val - self.minimum) / (self.maximum - self.minimum) - ) - else: - _cdf = np.nan_to_num(betainc(self.alpha, self.beta, - (val - self.minimum) / (self.maximum - self.minimum))) - _cdf[val < self.minimum] = 0. - _cdf[val > self.maximum] = 1. 
- return _cdf + ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val) + ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta)) + return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) + + @xp_wrap + def cdf(self, val, *, xp=None): + return xp.nan_to_num( + betainc( + xp.asarray(self.alpha), + xp.asarray(self.beta), + (val - self.minimum) / (self.maximum - self.minimum) + ) + ) + (val > self.maximum) class Logistic(Prior): @@ -1074,27 +1056,19 @@ def __init__(self, mu, scale, name=None, latex_label=None, unit=None, boundary=N self.mu = mu self.scale = scale - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Logistic prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - if isinstance(val, (float, int)): - if val == 0: - rescaled = -np.inf - elif val == 1: - rescaled = np.inf - else: - rescaled = self.mu + self.scale * np.log(val / (1. - val)) - else: - rescaled = np.inf * np.ones(val.size) - rescaled[val == 0] = -np.inf - rescaled[(val > 0) & (val < 1)] = self.mu + self.scale\ - * np.log(val[(val > 0) & (val < 1)] / (1. - val[(val > 0) & (val < 1)])) - return rescaled + with np.errstate(divide="ignore"): + val = xp.asarray(val) + return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), xp.asarray(0))) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1105,9 +1079,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. 
Parameters @@ -1118,11 +1093,13 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return -(val - self.mu) / self.scale -\ - 2. * np.log(1. + np.exp(-(val - self.mu) / self.scale)) - np.log(self.scale) + with np.errstate(over="ignore"): + return -(val - self.mu) / self.scale -\ + 2. * xp.log1p(xp.exp(-(val - self.mu) / self.scale)) - xp.log(xp.asarray(self.scale)) - def cdf(self, val): - return 1. / (1. + np.exp(-(val - self.mu) / self.scale)) + @xp_wrap + def cdf(self, val, *, xp=None): + return 1. / (1. + xp.exp(-(val - self.mu) / self.scale)) class Cauchy(Prior): @@ -1154,24 +1131,18 @@ def __init__(self, alpha, beta, name=None, latex_label=None, unit=None, boundary self.alpha = alpha self.beta = beta - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Cauchy prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - rescaled = self.alpha + self.beta * np.tan(np.pi * (val - 0.5)) - if isinstance(val, (float, int)): - if val == 1: - rescaled = np.inf - elif val == 0: - rescaled = -np.inf - else: - rescaled[val == 1] = np.inf - rescaled[val == 0] = -np.inf - return rescaled + rescaled = self.alpha + self.beta * xp.tan(np.pi * (val - 0.5)) + with np.errstate(divide="ignore", invalid="ignore"): + return rescaled - xp.log(val < 1) + xp.log(val > 0) - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1184,7 +1155,8 @@ def prob(self, val): """ return 1. / self.beta / np.pi / (1. + ((val - self.alpha) / self.beta) ** 2) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the log prior probability of val. Parameters @@ -1195,10 +1167,11 @@ def ln_prob(self, val): ======= Union[float, array_like]: Log prior probability of val """ - return - np.log(self.beta * np.pi) - np.log(1. 
+ ((val - self.alpha) / self.beta) ** 2) + return - xp.log(xp.asarray(self.beta * np.pi)) - xp.log(1. + ((val - self.alpha) / self.beta) ** 2) - def cdf(self, val): - return 0.5 + np.arctan((val - self.alpha) / self.beta) / np.pi + @xp_wrap + def cdf(self, val, *, xp=None): + return 0.5 + xp.arctan((val - self.alpha) / self.beta) / np.pi class Lorentzian(Cauchy): @@ -1235,15 +1208,17 @@ def __init__(self, k, theta=1., name=None, latex_label=None, unit=None, boundary self.k = k self.theta = theta - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Gamma prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - return gammaincinv(self.k, val) * self.theta + return xp.asarray(gammaincinv(self.k, val)) * self.theta - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1254,9 +1229,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -1267,28 +1243,16 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val < self.minimum: - _ln_prob = -np.inf - else: - _ln_prob = xlogy(self.k - 1, val) - val / self.theta - xlogy(self.k, self.theta) - gammaln(self.k) - else: - _ln_prob = -np.inf * np.ones(val.size) - idx = (val >= self.minimum) - _ln_prob[idx] = xlogy(self.k - 1, val[idx]) - val[idx] / self.theta\ + with np.errstate(divide="ignore"): + ln_prob = ( + xlogy(xp.asarray(self.k - 1), val) - val / self.theta - xlogy(self.k, self.theta) - gammaln(self.k) - return _ln_prob - - def cdf(self, val): - if isinstance(val, (float, int)): - if val < self.minimum: - _cdf = 0. 
- else: - _cdf = gammainc(self.k, val / self.theta) - else: - _cdf = np.zeros(val.size) - _cdf[val >= self.minimum] = gammainc(self.k, val[val >= self.minimum] / self.theta) - return _cdf + ) + xp.log(val >= self.minimum) + return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=xp.inf) + + @xp_wrap + def cdf(self, val, *, xp=None): + return gammainc(xp.asarray(self.k), xp.maximum(val, xp.asarray(self.minimum)) / self.theta) class ChiSquared(Gamma): @@ -1375,9 +1339,11 @@ def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, raise ValueError("For the Fermi-Dirac prior the values of sigma and r " "must be positive.") - self.expr = np.exp(self.r) + xp = array_module((mu, sigma, r)) + self.expr = xp.exp(self.r) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Fermi-Dirac prior. @@ -1395,9 +1361,10 @@ def rescale(self, val): `_, 2017. """ inv = -1 / self.expr + (1 + self.expr)**-val + (1 + self.expr)**-val / self.expr - return -self.sigma * np.log(np.maximum(inv, 0)) + return -self.sigma * xp.log(xp.maximum(inv, xp.asarray(0))) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1409,12 +1376,13 @@ def prob(self, val): float: Prior probability of val """ return ( - (np.exp((val - self.mu) / self.sigma) + 1)**-1 - / (self.sigma * np.log1p(self.expr)) + (xp.exp((val - self.mu) / self.sigma) + 1)**-1 + / (self.sigma * xp.log1p(xp.asarray(self.expr))) * (val >= self.minimum) ) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the log prior probability of val. 
Parameters @@ -1425,9 +1393,10 @@ def ln_prob(self, val): ======= Union[float, array_like]: Log prior probability of val """ - return np.log(self.prob(val)) + return xp.log(self.prob(val)) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): """ Evaluate the CDF of the Fermi-Dirac distribution using a slightly modified form of Equation 23 of [1]_. @@ -1449,10 +1418,10 @@ def cdf(self, val): `_, 2017. """ result = ( - (np.logaddexp(0, -self.r) - np.logaddexp(-val / self.sigma, -self.r)) - / np.logaddexp(0, self.r) + (xp.logaddexp(xp.asarray(0.0), -xp.asarray(self.r)) - xp.logaddexp(-val / self.sigma, -xp.asarray(self.r))) + / xp.logaddexp(xp.asarray(0.0), xp.asarray(self.r)) ) - return np.clip(result, 0, 1) + return xp.clip(result, 0, 1) class WeightedDiscreteValues(Prior): @@ -1482,20 +1451,21 @@ def __init__( The unit of the parameter. Used for plotting. """ + xp = array_module(values) nvalues = len(values) - values = np.array(values) + values = xp.asarray(values) if values.shape != (nvalues,): raise ValueError( f"Shape of argument 'values' must be 1d array-like but has shape {values.shape}" ) - minimum = np.min(values) + minimum = xp.min(values) # Small delta added to help with MCMC walking - maximum = np.max(values) * (1 + 1e-15) + maximum = xp.max(values) * (1 + 1e-15) super(WeightedDiscreteValues, self).__init__( name=name, latex_label=latex_label, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary) self.nvalues = nvalues - sorter = np.argsort(values) + sorter = xp.argsort(values) self._values_array = values[sorter] # inititialization of priors from repr only supports @@ -1503,9 +1473,9 @@ def __init__( self.values = self._values_array.tolist() weights = ( - np.array(weights) / np.sum(weights) + xp.asarray(weights) / xp.sum(weights) if weights is not None - else np.ones(self.nvalues) / self.nvalues + else xp.ones(self.nvalues) / self.nvalues ) # check for consistent shape of input if weights.shape != (self.nvalues,): @@ -1516,14 
+1486,15 @@ def __init__( ) self._weights_array = weights[sorter] self.weights = self._weights_array.tolist() - self._lnweights_array = np.log(self._weights_array) + self._lnweights_array = xp.log(self._weights_array) # save cdf for rescaling - _cumulative_weights_array = np.cumsum(self._weights_array) + _cumulative_weights_array = xp.cumsum(self._weights_array) # insert 0 for values smaller than minimum - self._cumulative_weights_array = np.insert(_cumulative_weights_array, 0, 0) + self._cumulative_weights_array = xp.insert(_cumulative_weights_array, 0, 0) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the discrete-value prior. @@ -1538,10 +1509,11 @@ def rescale(self, val): ======= Union[float, array_like]: Rescaled probability """ - index = np.searchsorted(self._cumulative_weights_array[1:], val) - return self._values_array[index] + index = xp.searchsorted(xp.asarray(self._cumulative_weights_array[1:]), val) + return xp.asarray(self._values_array)[index] - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): """Return the cumulative prior probability of val. Parameters @@ -1552,10 +1524,11 @@ def cdf(self, val): ======= float: cumulative prior probability of val """ - index = np.searchsorted(self._values_array, val, side="right") - return self._cumulative_weights_array[index] + index = xp.searchsorted(xp.asarray(self._values_array), val, side="right") + return xp.asarray(self._cumulative_weights_array)[index] - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. 
Parameters @@ -1566,13 +1539,18 @@ ======= float: Prior probability of val """ - index = np.searchsorted(self._values_array, val) - index = np.clip(index, 0, self.nvalues - 1) - p = np.where(self._values_array[index] == val, self._weights_array[index], 0) + index = xp.searchsorted(xp.asarray(self._values_array), val) + index = xp.clip(index, 0, self.nvalues - 1) + p = xp.where( + xp.asarray(self._values_array[index]) == val, + xp.asarray(self._weights_array[index]), + xp.asarray(0.0), + ) # turn 0d numpy array to scalar return p[()] - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the logarithmic prior probability of val Parameters @@ -1584,12 +1562,14 @@ float: """ - index = np.searchsorted(self._values_array, val) - index = np.clip(index, 0, self.nvalues - 1) - lnp = np.where( - self._values_array[index] == val, self._lnweights_array[index], -np.inf + index = xp.searchsorted(xp.asarray(self._values_array), val) + index = xp.clip(index, 0, self.nvalues - 1) + lnp = xp.where( + xp.asarray(self._values_array[index]) == val, + xp.asarray(self._lnweights_array[index]), + -np.inf, ) - # turn 0d numpy array to scalar + # turn 0d array to scalar return lnp[()] @@ -1713,7 +1693,7 @@ def __init__(self, mode, minimum, maximum, name=None, latex_label=None, unit=Non self.rescaled_minimum = self.minimum - (self.minimum == self.mode) * self.scale self.rescaled_maximum = self.maximum + (self.maximum == self.mode) * self.scale - def rescale(self, val): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from standard uniform to a triangular distribution.
@@ -1735,7 +1715,7 @@ def rescale(self, val): self.maximum - above_mode ) * (val >= self.fractional_mode) - def prob(self, val): + def prob(self, val, *, xp=None): """ Return the prior probability of val @@ -1762,7 +1742,7 @@ def prob(self, val): ) return 2.0 * (between_minimum_and_mode + between_mode_and_maximum) / self.scale - def cdf(self, val): + def cdf(self, val, *, xp=None): """ Return the prior cumulative probability at val @@ -1789,3 +1769,5 @@ def cdf(self, val): / (self.mode - self.rescaled_minimum) ) ) + + betaln, diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index 5ca28de28..137f6df5d 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -2,7 +2,9 @@ import json import os import re +import warnings +import array_api_compat as aac import numpy as np import scipy.stats @@ -12,8 +14,8 @@ decode_bilby_json, logger, get_dict_with_properties, - WrappedInterp1d as interp1d, ) +from ...compat.utils import xp_wrap class Prior(object): @@ -57,6 +59,27 @@ def __init__(self, name=None, latex_label=None, unit=None, minimum=-np.inf, self.boundary = boundary self._is_fixed = False + def __init_subclass__(cls): + for method_name in ["prob", "ln_prob", "rescale", "cdf", "sample"]: + method = getattr(cls, method_name, None) + if method is not None: + from inspect import signature + + sig = signature(method) + if "xp" not in sig.parameters: + warnings.warn( + f"The method {method_name} of the prior class " + f"{cls.__name__} does not accept an 'xp' keyword " + "argument. This may cause some behaviour to fail. " + "Please see the bilby documentation for more " + "information: https://bilby-dev.github.io/bilby/" + "array_api.html" + f" {sig}", + DeprecationWarning, + stacklevel=2, + ) + setattr(cls, method_name, xp_wrap(method, no_xp=True)) + def __call__(self): """Overrides the __call__ special method. Calls the sample method. 
@@ -106,7 +129,7 @@ def __eq__(self, other): for key in this_dict: if key == "least_recently_sampled": continue - if isinstance(this_dict[key], np.ndarray): + if aac.is_array_api_obj(this_dict[key]): if not np.array_equal(this_dict[key], other_dict[key]): return False elif isinstance(this_dict[key], type(scipy.stats.beta(1., 1.))): @@ -116,7 +139,7 @@ def __eq__(self, other): return False return True - def sample(self, size=None): + def sample(self, size=None, *, xp=np): """Draw a sample from the prior Parameters @@ -131,10 +154,12 @@ def sample(self, size=None): """ from ..utils import random - self.least_recently_sampled = self.rescale(random.rng.uniform(0, 1, size)) + self.least_recently_sampled = self.rescale( + xp.asarray(random.rng.uniform(0, 1, size)) + ) return self.least_recently_sampled - def rescale(self, val): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the prior. @@ -152,7 +177,7 @@ def rescale(self, val): """ return None - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val, this should be overwritten Parameters @@ -166,24 +191,22 @@ def prob(self, val): """ return np.nan - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): """ Generic method to calculate CDF, can be overwritten in subclass """ from scipy.integrate import cumulative_trapezoid if np.any(np.isinf([self.minimum, self.maximum])): raise ValueError( "Unable to use the generic CDF calculation for priors with" "infinite support") - x = np.linspace(self.minimum, self.maximum, 1000) - pdf = self.prob(x) + x = xp.linspace(self.minimum, self.maximum, 1000) + pdf = self.prob(x, xp=xp) cdf = cumulative_trapezoid(pdf, x, initial=0) - interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False, - fill_value=(0, 1)) - output = interp(val) - if isinstance(val, (int, float)): - output = float(output) - return output - - def ln_prob(self, val): + output = xp.interp(val, x, cdf / cdf[-1], left=0, 
right=1) + return output[()] + + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the prior ln probability of val, this should be overwritten Parameters @@ -196,7 +219,7 @@ def ln_prob(self, val): """ with np.errstate(divide='ignore'): - return np.log(self.prob(val)) + return xp.log(self.prob(val, xp=xp)) def is_in_prior_range(self, val): """Returns True if val is in the prior boundaries, zero otherwise @@ -473,7 +496,7 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, latex_label=latex_label, unit=unit) self._is_fixed = True - def prob(self, val): + def prob(self, val, *, xp=None): return (val > self.minimum) & (val < self.maximum) diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index ad142c2a9..f42f83239 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -1,9 +1,12 @@ +import numpy as np + from .base import Prior, PriorException from .interpolated import Interped from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \ LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac from ..utils import infer_args_from_method, infer_parameters_from_function +from ...compat.utils import xp_wrap def conditional_prior_factory(prior_class): @@ -59,7 +62,7 @@ def condition_func(reference_params, y): self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__) self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__) - def sample(self, size=None, **required_variables): + def sample(self, size=None, *, xp=np, **required_variables): """Draw a sample from the prior Parameters @@ -76,10 +79,15 @@ def sample(self, size=None, **required_variables): """ from ..utils import random - self.least_recently_sampled = self.rescale(random.rng.uniform(0, 1, size), **required_variables) + self.least_recently_sampled = self.rescale( + 
xp.asarray(random.rng.uniform(0, 1, size)), + xp=xp, + **required_variables, + ) return self.least_recently_sampled - def rescale(self, val, **required_variables): + @xp_wrap + def rescale(self, val, *, xp=None, **required_variables): """ 'Rescale' a sample from the unit line element to the prior. @@ -93,9 +101,10 @@ def rescale(self, val, **required_variables): """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).rescale(val) + return super(ConditionalPrior, self).rescale(val, xp=xp) - def prob(self, val, **required_variables): + @xp_wrap + def prob(self, val, *, xp=None, **required_variables): """Return the prior probability of val. Parameters @@ -111,9 +120,10 @@ def prob(self, val, **required_variables): float: Prior probability of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).prob(val) + return super(ConditionalPrior, self).prob(val, xp=xp) - def ln_prob(self, val, **required_variables): + @xp_wrap + def ln_prob(self, val, *, xp=None, **required_variables): """Return the natural log prior probability of val. Parameters @@ -129,9 +139,10 @@ def ln_prob(self, val, **required_variables): float: Natural log prior probability of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).ln_prob(val) + return super(ConditionalPrior, self).ln_prob(val, xp=xp) - def cdf(self, val, **required_variables): + @xp_wrap + def cdf(self, val, *, xp=None, **required_variables): """Return the cdf of val. Parameters @@ -147,7 +158,7 @@ def cdf(self, val, **required_variables): float: CDF of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).cdf(val) + return super(ConditionalPrior, self).cdf(val, xp=xp) def update_conditions(self, **required_variables): """ @@ -164,6 +175,7 @@ class depending on the required variables it depends on. self.reference_params will be used. 
""" + required_variables.pop("xp", None) if sorted(list(required_variables)) == sorted(self.required_variables): parameters = self.condition_func(self.reference_params.copy(), **required_variables) for key, value in parameters.items(): diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 3ac54622e..c3e61d569 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -5,6 +5,7 @@ from io import open as ioopen from warnings import warn +import array_api_compat as aac import numpy as np from .analytical import DeltaFunction @@ -16,6 +17,7 @@ BilbyJsonEncoder, decode_bilby_json, ) +from ...compat.utils import array_module, xp_wrap class PriorDict(dict): @@ -54,12 +56,16 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None): else: self.conversion_function = self.default_conversion_function - def evaluate_constraints(self, sample): + def __hash__(self): + return hash(str(self)) + + @xp_wrap + def evaluate_constraints(self, sample, *, xp=None): out_sample = self.conversion_function(sample) try: - prob = np.ones_like(next(iter(out_sample.values()))) + prob = xp.ones_like(next(iter(out_sample.values())), dtype=bool) except TypeError: - prob = np.ones_like(out_sample) + prob = xp.ones_like(out_sample, dtype=bool) for key in self: if isinstance(self[key], Constraint) and key in out_sample: prob *= self[key].prob(out_sample[key]) @@ -349,7 +355,7 @@ def fill_priors(self, likelihood=None, default_priors_file=None): for key in self: self.test_redundancy(key) - def sample(self, size=None): + def sample(self, size=None, *, xp=np): """Draw samples from the prior set Parameters @@ -361,9 +367,9 @@ def sample(self, size=None): ======= dict: Dictionary of the samples """ - return self.sample_subset_constrained(keys=list(self.keys()), size=size) + return self.sample_subset_constrained(keys=list(self.keys()), size=size, xp=xp) - def sample_subset_constrained_as_array(self, keys=iter([]), size=None): + def 
sample_subset_constrained_as_array(self, keys=iter([]), size=None, *, xp=np): """Return an array of samples Parameters @@ -378,12 +384,12 @@ def sample_subset_constrained_as_array(self, keys=iter([]), size=None): array: array_like An array of shape (len(key), size) of the samples (ordered by keys) """ - samples_dict = self.sample_subset_constrained(keys=keys, size=size) - samples_dict = {key: np.atleast_1d(val) for key, val in samples_dict.items()} + samples_dict = self.sample_subset_constrained(keys=keys, size=size, xp=xp) + samples_dict = {key: xp.atleast_1d(val) for key, val in samples_dict.items()} samples_list = [samples_dict[key] for key in keys] - return np.array(samples_list) + return xp.stack(samples_list) - def sample_subset(self, keys=iter([]), size=None): + def sample_subset(self, keys=iter([]), size=None, *, xp=np): """Draw samples from the prior set for parameters which are not a DeltaFunction Parameters @@ -403,7 +409,7 @@ def sample_subset(self, keys=iter([]), size=None): if isinstance(self[key], Constraint): continue elif isinstance(self[key], Prior): - samples[key] = self[key].sample(size=size) + samples[key] = self[key].sample(size=size, xp=xp) else: logger.debug("{} not a known prior.".format(key)) return samples @@ -426,7 +432,7 @@ def fixed_keys(self): def constraint_keys(self): return [k for k, p in self.items() if isinstance(p, Constraint)] - def sample_subset_constrained(self, keys=iter([]), size=None): + def sample_subset_constrained(self, keys=iter([]), size=None, *, xp=np): """ Sample a subset of priors while ensuring constraints are satisfied. @@ -442,7 +448,7 @@ def sample_subset_constrained(self, keys=iter([]), size=None): dict: Dictionary of valid samples. 
""" if not any(isinstance(self[key], Constraint) for key in self): - return self.sample_subset(keys=keys, size=size) + return self.sample_subset(keys=keys, size=size, xp=xp) efficiency_warning_was_issued = False @@ -458,10 +464,10 @@ def check_efficiency(n_tested, n_valid): n_tested_samples, n_valid_samples = 0, 0 if size is None or size == 1: while True: - sample = self.sample_subset(keys=keys, size=size) + sample = self.sample_subset(keys=keys, size=size, xp=xp) is_valid = self.evaluate_constraints(sample) n_tested_samples += 1 - n_valid_samples += int(is_valid) + n_valid_samples += int(is_valid.item()) check_efficiency(n_tested_samples, n_valid_samples) if is_valid: return sample @@ -470,20 +476,22 @@ def check_efficiency(n_tested, n_valid): for key in keys.copy(): if isinstance(self[key], Constraint): del keys[keys.index(key)] - all_samples = {key: np.array([]) for key in keys} + all_samples = {key: xp.asarray([]) for key in keys} _first_key = list(all_samples.keys())[0] while len(all_samples[_first_key]) < needed: - samples = self.sample_subset(keys=keys, size=needed) - keep = np.array(self.evaluate_constraints(samples), dtype=bool) + samples = self.sample_subset(keys=keys, size=needed, xp=xp) + keep = self.evaluate_constraints(samples, xp=xp) for key in keys: - all_samples[key] = np.hstack( + all_samples[key] = xp.hstack( [all_samples[key], samples[key][keep].flatten()] ) n_tested_samples += needed - n_valid_samples += np.sum(keep) + n_valid_samples += int(xp.sum(keep)) check_efficiency(n_tested_samples, n_valid_samples) + if not isinstance(size, tuple): + size = (size,) all_samples = { - key: np.reshape(all_samples[key][:needed], size) for key in keys + key: xp.reshape(all_samples[key][:needed], size) for key in keys } return all_samples @@ -508,22 +516,23 @@ def normalize_constraint_factor( self._cached_normalizations[keys] = factor_rounded return factor_rounded - def _estimate_normalization(self, keys, min_accept, sampling_chunk): - samples = 
self.sample_subset(keys=keys, size=sampling_chunk) + def _estimate_normalization(self, keys, min_accept, sampling_chunk, *, xp=np): + samples = self.sample_subset(keys=keys, size=sampling_chunk, xp=xp) keep = np.atleast_1d(self.evaluate_constraints(samples)) if len(keep) == 1: self._cached_normalizations[keys] = 1 return 1 all_samples = {key: np.array([]) for key in keys} while np.count_nonzero(keep) < min_accept: - samples = self.sample_subset(keys=keys, size=sampling_chunk) + samples = self.sample_subset(keys=keys, size=sampling_chunk, xp=xp) for key in samples: all_samples[key] = np.hstack([all_samples[key], samples[key].flatten()]) keep = np.array(self.evaluate_constraints(all_samples), dtype=bool) factor = len(keep) / np.count_nonzero(keep) return factor - def prob(self, sample, **kwargs): + @xp_wrap + def prob(self, sample, *, xp=None, **kwargs): """ Parameters @@ -538,29 +547,28 @@ def prob(self, sample, **kwargs): float: Joint probability of all individual sample probabilities """ - prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs) + prob = xp.prod(xp.stack([self[key].prob(sample[key], xp=xp) for key in sample]), **kwargs) - return self.check_prob(sample, prob) + return self.check_prob(sample, prob, xp=xp) - def check_prob(self, sample, prob): + @xp_wrap + def check_prob(self, sample, prob, *, xp=None): ratio = self.normalize_constraint_factor(tuple(sample.keys())) - if np.all(prob == 0.0): + if not aac.is_jax_namespace(xp) and xp.all(prob == 0.0): return prob * ratio else: if isinstance(prob, float): - if self.evaluate_constraints(sample): + if self.evaluate_constraints(sample, xp=xp): return prob * ratio else: return 0.0 else: - constrained_prob = np.zeros_like(prob) - in_bounds = np.isfinite(prob) - subsample = {key: sample[key][in_bounds] for key in sample} - keep = np.array(self.evaluate_constraints(subsample), dtype=bool) - constrained_prob[in_bounds] = prob[in_bounds] * keep * ratio + keep = self.evaluate_constraints(sample, 
xp=xp) + constrained_prob = xp.where(keep, prob * ratio, 0.0) return constrained_prob - def ln_prob(self, sample, axis=None, normalized=True): + @xp_wrap + def ln_prob(self, sample, axis=None, normalized=True, *, xp=None): """ Parameters @@ -579,32 +587,30 @@ def ln_prob(self, sample, axis=None, normalized=True): Joint log probability of all the individual sample probabilities """ - ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis) - return self.check_ln_prob(sample, ln_prob, - normalized=normalized) + ln_prob = xp.sum(xp.stack([self[key].ln_prob(sample[key], xp=xp) for key in sample]), axis=axis) + return self.check_ln_prob(sample, ln_prob, normalized=normalized, xp=xp) - def check_ln_prob(self, sample, ln_prob, normalized=True): + @xp_wrap + def check_ln_prob(self, sample, ln_prob, normalized=True, *, xp=None): if normalized: ratio = self.normalize_constraint_factor(tuple(sample.keys())) else: ratio = 1 - if np.all(np.isinf(ln_prob)): + if not aac.is_jax_namespace(xp) and xp.all(xp.isinf(ln_prob)): return ln_prob else: if isinstance(ln_prob, float): - if np.all(self.evaluate_constraints(sample)): - return ln_prob + np.log(ratio) + if xp.all(self.evaluate_constraints(sample, xp=xp)): + return ln_prob + xp.log(ratio) else: return -np.inf else: - constrained_ln_prob = -np.inf * np.ones_like(ln_prob) - in_bounds = np.isfinite(ln_prob) - subsample = {key: sample[key][in_bounds] for key in sample} - keep = np.log(np.array(self.evaluate_constraints(subsample), dtype=bool)) - constrained_ln_prob[in_bounds] = ln_prob[in_bounds] + keep + np.log(ratio) + keep = self.evaluate_constraints(sample, xp=xp) + constrained_ln_prob = xp.where(keep, ln_prob + xp.log(ratio), -xp.inf) return constrained_ln_prob - def cdf(self, sample): + @xp_wrap + def cdf(self, sample, *, xp=None): """Evaluate the cumulative distribution function at the provided points Parameters @@ -618,10 +624,10 @@ def cdf(self, sample): """ return sample.__class__(
self[key].cdf(sample) for key, sample in sample.items()} + {key: self[key].cdf(sample, xp=xp) for key, sample in sample.items()} ) - def rescale(self, keys, theta): + def rescale(self, keys, theta, *, xp=None): """Rescale samples from unit cube to prior Parameters @@ -635,9 +641,12 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ - return list( - [self[key].rescale(sample) for key, sample in zip(keys, theta)] - ) + if isinstance(theta, {}.values().__class__): + theta = list(theta) + if xp is None: + xp = array_module(theta) + + return xp.asarray([self[key].rescale(sample, xp=xp) for key, sample in zip(keys, theta)]) def test_redundancy(self, key, disable_logging=False): """Empty redundancy test, should be overwritten in subclasses""" @@ -737,7 +746,7 @@ def _check_conditions_resolved(self, key, sampled_keys): conditions_resolved = False return conditions_resolved - def sample_subset(self, keys=iter([]), size=None): + def sample_subset(self, keys=iter([]), size=None, *, xp=np): self.convert_floats_to_delta_functions() add_delta_keys = [ key @@ -757,7 +766,9 @@ def sample_subset(self, keys=iter([]), size=None): if isinstance(self[key], Prior): try: samples[key] = subset_dict[key].sample( - size=size, **subset_dict.get_required_variables(key) + size=size, + xp=xp, + **subset_dict.get_required_variables(key), ) except ValueError: # Some prior classes can not handle an array of conditional parameters (e.g. 
alpha for PowerLaw) @@ -768,7 +779,10 @@ def sample_subset(self, keys=iter([]), size=None): rvars = { key: value[i] for key, value in required_variables.items() } - samples[key][i] = subset_dict[key].sample(**rvars) + samples[key][i] = subset_dict[key].sample( + **rvars, + xp=xp, + ) else: logger.debug("{} not a known prior.".format(key)) return samples @@ -790,7 +804,8 @@ def get_required_variables(self, key): for k in getattr(self[key], "required_variables", []) } - def prob(self, sample, **kwargs): + @xp_wrap + def prob(self, sample, *, xp=None, **kwargs): """ Parameters @@ -806,12 +821,12 @@ def prob(self, sample, **kwargs): """ self._prepare_evaluation(*zip(*sample.items())) - res = [ - self[key].prob(sample[key], **self.get_required_variables(key)) + res = xp.asarray([ + self[key].prob(sample[key], **self.get_required_variables(key), xp=xp) for key in sample - ] - prob = np.prod(res, **kwargs) - return self.check_prob(sample, prob) + ]) + prob = xp.prod(res, **kwargs) + return self.check_prob(sample, prob, xp=xp) def ln_prob(self, sample, axis=None, normalized=True): """ @@ -832,18 +847,21 @@ def ln_prob(self, sample, axis=None, normalized=True): """ self._prepare_evaluation(*zip(*sample.items())) - res = [ + xp = array_module(sample.values()) + res = xp.asarray([ self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample - ] - ln_prob = np.sum(res, axis=axis) - return self.check_ln_prob(sample, ln_prob, - normalized=normalized) - - def cdf(self, sample): + ]) + ln_prob = xp.sum(res, axis=axis) + return self.check_ln_prob(sample, ln_prob, + normalized=normalized, + xp=xp) + + @xp_wrap + def cdf(self, sample, *, xp=None): self._prepare_evaluation(*zip(*sample.items())) res = { - key: self[key].cdf(sample[key], **self.get_required_variables(key)) + key: self[key].cdf(sample[key], **self.get_required_variables(key), xp=xp) for key in sample } return sample.__class__(res) @@ -862,8 +880,11 @@ def rescale(self, keys, theta): 
======= list: List of floats containing the rescaled sample """ + if isinstance(theta, {}.values().__class__): + theta = list(theta) + xp = array_module(theta) + keys = list(keys) - theta = list(theta) self._check_resolved() self._update_rescale_keys(keys) result = dict() @@ -880,30 +901,14 @@ def rescale(self, keys, theta): elif isinstance(self[key], JointPrior): joint[self[key].dist.distname].append(key) for names in joint.values(): - # this is needed to unpack how joint prior rescaling works - # as an example of a joint prior over {a, b, c, d} we might - # get the following based on the order within the joint prior - # {a: [], b: [], c: [1, 2, 3, 4], d: []} - # -> [1, 2, 3, 4] - # -> {a: 1, b: 2, c: 3, d: 4} - values = list() for key in names: - values = np.concatenate([values, result[key]]) - for key, value in zip(names, values): - result[key] = value - - def safe_flatten(value): - """ - this is gross but can be removed whenever we switch to returning - arrays, flatten converts 0-d arrays to 1-d so has to be special - cased - """ - if isinstance(value, (float, int, np.int64)): - return value - else: - return result[key].flatten() + if result[key] is None: + continue + for subkey, val in zip(self[key].dist.names, result[key]): + self[subkey].least_recently_sampled = val + result[subkey] = val - return [safe_flatten(result[key]) for key in keys] + return xp.asarray([result[key] for key in keys]) def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index 5fbf8f9c1..1983877d7 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -2,7 +2,9 @@ from scipy.integrate import trapezoid from .base import Prior -from ..utils import logger, WrappedInterp1d as interp1d +from ..utils import logger +from ..utils.calculus import interp1d +from ...compat.utils import xp_wrap class Interped(Prior): @@ -64,7 +66,8 @@ def __eq__(self, 
other): return True return False - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -75,18 +78,20 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return self.probability_density(val) + return self.probability_density(val)[()] - def cdf(self, val): - return self.cumulative_distribution(val) + @xp_wrap + def cdf(self, val, *, xp=None): + return self.cumulative_distribution(val)[()] - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the prior. This maps to the inverse CDF. This is done using interpolation. """ - return self.inverse_cumulative_distribution(val) + return self.inverse_cumulative_distribution(val)[()] @property def minimum(self): diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 43c8913e3..238c0d791 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -1,5 +1,6 @@ import re +import array_api_extra as xpx import numpy as np import scipy.stats from scipy.special import erfinv @@ -7,6 +8,7 @@ from .base import Prior, PriorException from ..utils import logger, infer_args_from_method, get_dict_with_properties from ..utils import random +from ...compat.utils import xp_wrap class BaseJointPriorDist(object): @@ -172,13 +174,14 @@ def _split_repr(cls, string): kwargs[key.strip()] = arg return kwargs - def prob(self, samp): + @xp_wrap + def prob(self, samp, *, xp=None): """ Get the probability of a sample. For bounded priors the probability will not be properly normalised. """ - return np.exp(self.ln_prob(samp)) + return xp.exp(self.ln_prob(samp, xp=xp)) def _check_samp(self, value): """ @@ -216,7 +219,8 @@ def _check_samp(self, value): break return samp, outbounds - def ln_prob(self, value): + @xp_wrap + def ln_prob(self, value, *, xp=None): """ Get the log-probability of a sample. 
For bounded priors the probability will not be properly normalised. @@ -230,14 +234,12 @@ def ln_prob(self, value): """ samp, outbounds = self._check_samp(value) - lnprob = -np.inf * np.ones(samp.shape[0]) - lnprob = self._ln_prob(samp, lnprob, outbounds) - if samp.shape[0] == 1: - return lnprob[0] - else: - return lnprob + lnprob = -np.inf * xp.ones(samp.shape[0]) + lnprob = self._ln_prob(samp, lnprob, outbounds, xp=xp) + return lnprob[()] - def _ln_prob(self, samp, lnprob, outbounds): + @xp_wrap + def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): """ Get the log-probability of a sample. For bounded priors the probability will not be properly normalised. **this method needs overwritten by child class** @@ -261,7 +263,7 @@ def _ln_prob(self, samp, lnprob, outbounds): """ return lnprob - def sample(self, size=1, **kwargs): + def sample(self, size=1, *, xp=np, **kwargs): """ Draw, and set, a sample from the Dist, accompanying method _sample needs to overwritten @@ -273,14 +275,11 @@ def sample(self, size=1, **kwargs): if size is None: size = 1 - samps = self._sample(size=size, **kwargs) + samps = self._sample(size=size, xp=xp, **kwargs) for i, name in enumerate(self.names): - if size == 1: - self.current_sample[name] = samps[:, i].flatten()[0] - else: - self.current_sample[name] = samps[:, i].flatten() + self.current_sample[name] = samps[:, i].flatten()[()] - def _sample(self, size, **kwargs): + def _sample(self, size, *, xp=np, **kwargs): """ Draw, and set, a sample from the joint dist (**needs to be ovewritten by child class**) @@ -289,13 +288,14 @@ def _sample(self, size, **kwargs): size: int number of samples to generate, defaults to 1 """ - samps = np.zeros((size, len(self))) + samps = xp.zeros((size, len(self))) """ Here is where the subclass where overwrite sampling method """ return samps - def rescale(self, value, **kwargs): + @xp_wrap + def rescale(self, value, *, xp=None, **kwargs): """ Rescale from a unit hypercube to JointPriorDist. 
Note that no bounds are applied in the rescale function. (child classes need to @@ -317,7 +317,7 @@ def rescale(self, value, **kwargs): An vector sample drawn from the multivariate Gaussian distribution. """ - samp = np.array(value) + samp = xp.asarray(value) if len(samp.shape) == 1: samp = samp.reshape(1, self.num_vars) @@ -327,7 +327,9 @@ def rescale(self, value, **kwargs): raise ValueError("Array is the wrong shape") samp = self._rescale(samp, **kwargs) - return np.squeeze(samp) + if samp.shape[0] == 1: + samp = xp.squeeze(samp, axis=0) + return samp def _rescale(self, samp, **kwargs): """ @@ -611,7 +613,8 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): scipy.stats.multivariate_normal(mean=np.zeros(self.num_vars), cov=self.corrcoefs[-1]) ) - def _rescale(self, samp, **kwargs): + @xp_wrap + def _rescale(self, samp, *, xp=None, **kwargs): try: mode = kwargs["mode"] except KeyError: @@ -626,12 +629,12 @@ def _rescale(self, samp, **kwargs): samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 # rotate and scale to the multivariate normal shape - samp = self.mus[mode] + self.sigmas[mode] * np.einsum( - "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode] + samp = xp.asarray(self.mus[mode]) + xp.asarray(self.sigmas[mode]) * xp.einsum( + "ij,kj->ik", samp * self.sqeigvalues[mode], xp.asarray(self.eigvectors[mode]) ) return samp - def _sample(self, size, **kwargs): + def _sample(self, size, *, xp=np, **kwargs): try: mode = kwargs["mode"] except KeyError: @@ -673,18 +676,21 @@ def _sample(self, size, **kwargs): if not outbound: inbound = True - return samps + return xp.asarray(samps) - def _ln_prob(self, samp, lnprob, outbounds): + @xp_wrap + def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): for j in range(samp.shape[0]): # loop over the modes and sum the probabilities for i in range(self.nmodes): # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() z = (samp[j] - self.mus[i]) / self.sigmas[i] - 
lnprob[j] = np.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - self.logprodsigmas[i]) + lnprob = xpx.at(lnprob, j).set( + xp.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - xp.asarray(self.logprodsigmas[i])) + ) # set out-of-bounds values to -inf - lnprob[outbounds] = -np.inf + lnprob = xp.where(xp.asarray(outbounds), -np.inf, lnprob) return lnprob def __eq__(self, other): @@ -778,7 +784,8 @@ def maximum(self, maximum): self._maximum = maximum self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum) - def rescale(self, val, **kwargs): + @xp_wrap + def rescale(self, val, *, xp=None, **kwargs): """ Scale a unit hypercube sample to the prior. @@ -793,18 +800,21 @@ def rescale(self, val, **kwargs): float: A sample from the prior parameter. """ - self.dist.rescale_parameters[self.name] = val if self.dist.filled_rescale(): - values = np.array(list(self.dist.rescale_parameters.values())).T + # print(self.dist.rescale_parameters) + values = xp.stack([ + xp.asarray(val) for val in self.dist.rescale_parameters.values() + ]).T + # values = xp.asarray(list(self.dist.rescale_parameters.values())).T samples = self.dist.rescale(values, **kwargs) self.dist.reset_rescale() return samples else: return [] # return empty list - def sample(self, size=1, **kwargs): + def sample(self, size=1, *, xp=np, **kwargs): """ Draw a sample from the prior. @@ -829,7 +839,7 @@ def sample(self, size=1, **kwargs): if len(self.dist.current_sample) == 0: # generate a sample - self.dist.sample(size=size, **kwargs) + self.dist.sample(size=size, xp=xp, **kwargs) sample = self.dist.current_sample[self.name] @@ -842,7 +852,8 @@ def sample(self, size=1, **kwargs): self.least_recently_sampled = sample return sample - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """ Return the natural logarithm of the prior probability. 
Note that this will not be correctly normalised if there are bounds on the @@ -864,25 +875,16 @@ def ln_prob(self, val): values = list(self.dist.requested_parameters.values()) # check for the same number of values for each parameter - for i in range(len(self.dist) - 1): - if isinstance(values[i], (list, np.ndarray)) or isinstance( - values[i + 1], (list, np.ndarray) - ): - if isinstance(values[i], (list, np.ndarray)) and isinstance( - values[i + 1], (list, np.ndarray) - ): - if len(values[i]) != len(values[i + 1]): - raise ValueError( - "Each parameter must have the same " - "number of requested values." - ) - else: - raise ValueError( - "Each parameter must have the same " - "number of requested values." - ) + shapes = set() + for v in values: + shapes.add(xp.asarray(v).shape) + if len(shapes) > 1: + raise ValueError( + "Each parameter must have the same " + "number of requested values." + ) - lnp = self.dist.ln_prob(np.asarray(values).T) + lnp = self.dist.ln_prob(xp.stack(values).T) # reset the requested parameters self.dist.reset_request() @@ -901,9 +903,10 @@ def ln_prob(self, val): if len(val) == 1: return 0.0 else: - return np.zeros_like(val) + return xp.zeros_like(val) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -917,7 +920,7 @@ def prob(self, val): the p value for the prior at given sample """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val, xp=xp)) class MultivariateGaussian(JointPrior): diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index 6910be608..2ac310f55 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -1,8 +1,8 @@ -from numbers import Number import numpy as np from .base import Prior from ..utils import logger +from ...compat.utils import xp_wrap class SlabSpikePrior(Prior): @@ -72,7 +72,8 @@ def slab_fraction(self): def _find_inverse_cdf_fraction_before_spike(self): return 
float(self.slab.cdf(self.spike_location)) * self.slab_fraction - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the prior. @@ -85,28 +86,22 @@ def rescale(self, val): ======= array_like: Associated prior value with input value. """ - original_is_number = isinstance(val, Number) - val = np.atleast_1d(val) - lower_indices = val < self.inverse_cdf_below_spike - intermediate_indices = np.logical_and( - self.inverse_cdf_below_spike <= val, - val <= (self.inverse_cdf_below_spike + self.spike_height)) - higher_indices = val > (self.inverse_cdf_below_spike + self.spike_height) - - res = np.zeros(len(val)) - res[lower_indices] = self._contracted_rescale(val[lower_indices]) - res[intermediate_indices] = self.spike_location - res[higher_indices] = self._contracted_rescale(val[higher_indices] - self.spike_height) - if original_is_number: - try: - res = res[0] - except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + higher_indices = val >= (self.inverse_cdf_below_spike + self.spike_height) + + slab_scaled = self._contracted_rescale( + val - self.spike_height * higher_indices, xp=xp + ) + + res = xp.where( + lower_indices | higher_indices, + slab_scaled, + xp.asarray(self.spike_location), + ) return res - def _contracted_rescale(self, val): + @xp_wrap + def _contracted_rescale(self, val, *, xp=None): """ Contracted version of the rescale function that implements the `rescale` function on the pure slab part of the prior. @@ -120,9 +115,10 @@ def _contracted_rescale(self, val): ======= array_like: Associated prior value with input value. """ - return self.slab.rescale(val / self.slab_fraction) + return self.slab.rescale(val / self.slab_fraction, xp=xp) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. 
Returns np.inf for the spike location @@ -134,19 +130,13 @@ def prob(self, val): ======= array_like: Prior probability of val """ - original_is_number = isinstance(val, Number) res = self.slab.prob(val) * self.slab_fraction - res = np.atleast_1d(res) - res[val == self.spike_location] = np.inf - if original_is_number: - try: - res = res[0] - except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + with np.errstate(invalid="ignore"): + res += xp.nan_to_num(xp.inf * (val == self.spike_location), posinf=xp.inf) return res - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the Log prior probability of val. Returns np.inf for the spike location @@ -158,19 +148,13 @@ def ln_prob(self, val): ======= array_like: Prior probability of val """ - original_is_number = isinstance(val, Number) res = self.slab.ln_prob(val) + np.log(self.slab_fraction) - res = np.atleast_1d(res) - res[val == self.spike_location] = np.inf - if original_is_number: - try: - res = res[0] - except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + with np.errstate(divide="ignore"): + res += xp.nan_to_num(xp.inf * (val == self.spike_location), posinf=xp.inf) return res - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): """ Return the CDF of the prior. This calls to the slab CDF and adds a discrete step at the spike location. 
@@ -184,6 +168,6 @@ def cdf(self, val): array_like: CDF value of val """ - res = self.slab.cdf(val) * self.slab_fraction - res += self.spike_height * (val > self.spike_location) + res = self.slab.cdf(val, xp=xp) * self.slab_fraction + res += (val > self.spike_location) * self.spike_height return res diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index e10ce6111..f97ee1a9b 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -1,10 +1,12 @@ import math +import array_api_compat as aac import numpy as np -from scipy.interpolate import interp1d, RectBivariateSpline -from scipy.special import logsumexp +from scipy.interpolate import RectBivariateSpline, interp1d as _interp1d from .log import logger +from ...compat.patches import logsumexp +from ...compat.utils import array_module, xp_wrap def derivatives( @@ -152,7 +154,8 @@ def derivatives( return grads -def logtrapzexp(lnf, dx): +@xp_wrap +def logtrapzexp(lnf, dx, *, xp=np): """ Perform trapezium rule integration for the logarithm of a function on a grid. 
@@ -171,22 +174,45 @@ def logtrapzexp(lnf, dx): lnfdx1 = lnf[:-1] lnfdx2 = lnf[1:] - if isinstance(dx, (int, float)): - C = np.log(dx / 2.0) - elif isinstance(dx, (list, np.ndarray)): - if len(dx) != len(lnf) - 1: - raise ValueError( - "Step size array must have length one less than the function length" - ) - lndx = np.log(dx) - lnfdx1 = lnfdx1.copy() + lndx - lnfdx2 = lnfdx2.copy() + lndx - C = -np.log(2.0) - else: - raise TypeError("Step size must be a single value or array-like") + try: + dx = xp.asarray(dx) + except TypeError: + raise TypeError(f"Step size dx={dx} could not be converted to an array") + + if dx.ndim > 0 and len(dx) != len(lnf) - 1: + raise ValueError( + "Step size array must have length one less than the function length" + ) + lnfdx1 = lnfdx1 + xp.log(dx) + lnfdx2 = lnfdx2 + xp.log(dx) + + return logsumexp(xp.asarray([logsumexp(lnfdx1), logsumexp(lnfdx2)])) - np.log(2) - return C + logsumexp([logsumexp(lnfdx1), logsumexp(lnfdx2)]) + +class interp1d(_interp1d): + + def __call__(self, x): + from array_api_compat import is_numpy_namespace + + xp = array_module(x) + if is_numpy_namespace(xp): + return super().__call__(x) + else: + return self._call_alt(x, xp=xp) + + def _call_alt(self, x, *, xp=np): + if isinstance(self.fill_value, tuple): + left, right = self.fill_value + else: + left = right = self.fill_value + return xp.interp( + x, + xp.asarray(self.x), + xp.asarray(self.y), + left=left, + right=right, + ) class BoundedRectBivariateSpline(RectBivariateSpline): @@ -202,9 +228,23 @@ def __init__(self, x, y, z, bbox=[None] * 4, kx=3, ky=3, s=0, fill_value=None): if self.y_max is None: self.y_max = max(y) self.fill_value = fill_value + self.x = x + self.y = y + self.z = z super().__init__(x=x, y=y, z=z, bbox=bbox, kx=kx, ky=ky, s=s) def __call__(self, x, y, dx=0, dy=0, grid=False): + xp = array_module([x, y]) + if aac.is_numpy_namespace(xp): + return self._call_scipy(x, y, dx=dx, dy=dy, grid=grid) + elif aac.is_jax_namespace(xp): + return 
self._call_jax(x, y) + else: + raise NotImplementedError( + f"BoundedRectBivariateSpline not implemented for {xp.__name__} backend" + ) + + def _call_scipy(self, x, y, dx=0, dy=0, grid=False): result = super().__call__(x=x, y=y, dx=dx, dy=dy, grid=grid) out_of_bounds_x = (x < self.x_min) | (x > self.x_max) out_of_bounds_y = (y < self.y_min) | (y > self.y_max) @@ -218,6 +258,20 @@ def __call__(self, x, y, dx=0, dy=0, grid=False): else: return result + def _call_jax(self, x, y): + import jax.numpy as jnp + from interpax import interp2d + + return interp2d( + jnp.asarray(x), + jnp.asarray(y), + jnp.asarray(self.x), + jnp.asarray(self.y), + jnp.asarray(self.z), + extrap=self.fill_value if self.fill_value is not None else False, + method="cubic2", + ) + class WrappedInterp1d(interp1d): """ diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index 8299d6816..f4c9bc4e8 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -8,6 +8,7 @@ from pathlib import Path from datetime import timedelta +import array_api_compat as aac import numpy as np import pandas as pd @@ -59,8 +60,12 @@ def default(self, obj): return encode_astropy_unit(obj) except ImportError: logger.debug("Cannot import astropy, cannot write cosmological priors") - if isinstance(obj, np.ndarray): - return {"__array__": True, "content": obj.tolist()} + if aac.is_array_api_obj(obj): + return { + "__array__": True, + "__array_namespace__": aac.get_namespace(obj).__name__, + "content": obj.tolist(), + } if isinstance(obj, complex): return {"__complex__": True, "real": obj.real, "imag": obj.imag} if isinstance(obj, pd.DataFrame): @@ -320,7 +325,9 @@ def decode_bilby_json(dct): if dct.get("__astropy_unit__", False): return decode_astropy_unit(dct) if dct.get("__array__", False): - return np.asarray(dct["content"]) + namespace = dct.get("__array_namespace__", "numpy") + xp = import_module(namespace) + return xp.asarray(dct["content"]) if dct.get("__complex__", False): return complex(dct["real"], 
dct["imag"]) if dct.get("__dataframe__", False): @@ -438,6 +445,10 @@ def encode_for_hdf5(key, item): if item.dtype.kind == 'U': logger.debug(f'converting dtype {item.dtype} for hdf5') item = np.array(item, dtype='S') + elif aac.is_array_api_obj(item): + # temporarily dump all arrays as numpy arrays, we should figure ou + # how to properly deserialize them + item = np.asarray(item) if isinstance(item, (np.ndarray, int, float, complex, str, bytes)): output = item elif isinstance(item, np.random.Generator): diff --git a/bilby/core/utils/samples.py b/bilby/core/utils/samples.py index a075d6dcd..93fdac0ac 100644 --- a/bilby/core/utils/samples.py +++ b/bilby/core/utils/samples.py @@ -1,3 +1,4 @@ +import array_api_extra as xpx import numpy as np from scipy.special import logsumexp @@ -135,7 +136,7 @@ def reflect(u): u: array-like The input array, modified in place. """ - idxs_even = np.mod(u, 2) < 1 - u[idxs_even] = np.mod(u[idxs_even], 1) - u[~idxs_even] = 1 - np.mod(u[~idxs_even], 1) + idxs_even = (u % 2) < 1 + u = xpx.at(u)[idxs_even].set(u[idxs_even] % 1) + u = xpx.at(u)[~idxs_even].set(1 - (u[~idxs_even] % 1)) return u diff --git a/bilby/core/utils/series.py b/bilby/core/utils/series.py index 63daebd6e..4fa20b51a 100644 --- a/bilby/core/utils/series.py +++ b/bilby/core/utils/series.py @@ -1,4 +1,5 @@ import numpy as np +from ...compat.utils import array_module _TOL = 14 @@ -97,11 +98,14 @@ def create_time_series(sampling_frequency, duration, starting_time=0.): float: An equidistant time series given the parameters """ + xp = array_module(sampling_frequency) _check_legal_sampling_frequency_and_duration(sampling_frequency, duration) number_of_samples = int(duration * sampling_frequency) - return np.linspace(start=starting_time, - stop=duration + starting_time - 1 / sampling_frequency, - num=number_of_samples) + return xp.linspace( + starting_time, + duration + starting_time - 1 / sampling_frequency, + num=number_of_samples, + ) def 
create_frequency_series(sampling_frequency, duration): @@ -117,13 +121,12 @@ def create_frequency_series(sampling_frequency, duration): array_like: frequency series """ + xp = array_module(sampling_frequency) _check_legal_sampling_frequency_and_duration(sampling_frequency, duration) - number_of_samples = int(np.round(duration * sampling_frequency)) - number_of_frequencies = int(np.round(number_of_samples / 2) + 1) + number_of_samples = xp.round(duration * sampling_frequency) + number_of_frequencies = int(xp.round(number_of_samples / 2) + 1) - return np.linspace(start=0, - stop=sampling_frequency / 2, - num=number_of_frequencies) + return xp.linspace(0, sampling_frequency / 2, num=number_of_frequencies) def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration): @@ -139,7 +142,7 @@ def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration): """ num = sampling_frequency * duration - if np.abs(num - np.round(num)) > 10**(-_TOL): + if abs(num % 1) > 10**(-_TOL): raise IllegalDurationAndSamplingFrequencyException( '\nYour sampling frequency and duration must multiply to a number' 'up to (tol = {}) decimals close to an integer number. ' @@ -206,10 +209,11 @@ def nfft(time_domain_strain, sampling_frequency): strain / Hz, and the associated frequency_array. 
""" - frequency_domain_strain = np.fft.rfft(time_domain_strain) + xp = array_module(time_domain_strain) + frequency_domain_strain = xp.fft.rfft(time_domain_strain) frequency_domain_strain /= sampling_frequency - frequency_array = np.linspace( + frequency_array = xp.linspace( 0, sampling_frequency / 2, len(frequency_domain_strain)) return frequency_domain_strain, frequency_array @@ -231,7 +235,8 @@ def infft(frequency_domain_strain, sampling_frequency): time_domain_strain: array_like An array of the time domain strain """ - time_domain_strain_norm = np.fft.irfft(frequency_domain_strain) + xp = array_module(frequency_domain_strain) + time_domain_strain_norm = xp.fft.irfft(frequency_domain_strain) time_domain_strain = time_domain_strain_norm * sampling_frequency return time_domain_strain diff --git a/bilby/gw/__init__.py b/bilby/gw/__init__.py index b5115766b..cd09bc6f6 100644 --- a/bilby/gw/__init__.py +++ b/bilby/gw/__init__.py @@ -3,4 +3,5 @@ from .waveform_generator import WaveformGenerator, LALCBCWaveformGenerator from .likelihood import GravitationalWaveTransient from .detector import calibration +from . 
import compat diff --git a/bilby/gw/compat/__init__.py b/bilby/gw/compat/__init__.py new file mode 100644 index 000000000..36f2566c4 --- /dev/null +++ b/bilby/gw/compat/__init__.py @@ -0,0 +1,15 @@ +try: + from .jax import n_leap_seconds +except ModuleNotFoundError: + pass + + +try: + from .cython import gps_time_to_utc +except ModuleNotFoundError: + pass + +try: + from .torch import n_leap_seconds +except ModuleNotFoundError: + pass \ No newline at end of file diff --git a/bilby/gw/compat/cython.py b/bilby/gw/compat/cython.py new file mode 100644 index 000000000..9d0a69af0 --- /dev/null +++ b/bilby/gw/compat/cython.py @@ -0,0 +1,66 @@ +import numpy as np +from bilby_cython import time as _time, geometry as _geometry +from plum import dispatch + +from ...compat.types import Real, ArrayLike + + +@dispatch(precedence=1) +def gps_time_to_utc(gps_time: Real): + return _time.gps_time_to_utc(gps_time) + + +@dispatch(precedence=1) +def greenwich_mean_sidereal_time(gps_time: Real | ArrayLike): + return _time.greenwich_mean_sidereal_time(gps_time) + + +@dispatch(precedence=1) +def greenwich_sidereal_time(gps_time: Real, equation_of_equinoxes: Real): + return _time.greenwich_sidereal_time(gps_time, equation_of_equinoxes) + + +@dispatch(precedence=1) +def n_leap_seconds(gps_time: Real): + return _time.n_leap_seconds(gps_time) + + +@dispatch(precedence=1) +def utc_to_julian_day(utc_time: Real): + return _time.utc_to_julian_day(utc_time) + + +@dispatch(precedence=1) +def calculate_arm(arm_tilt: Real, arm_azimuth: Real, longitude: Real, latitude: Real): + return _geometry.calculate_arm(arm_tilt, arm_azimuth, longitude, latitude) + + +@dispatch(precedence=1) +def detector_tensor(x: ArrayLike, y: ArrayLike): + return _geometry.detector_tensor(x, y) + + +@dispatch(precedence=1) +def get_polarization_tensor(ra: Real, dec: Real, time: Real, psi: Real, mode: str): + return _geometry.get_polarization_tensor(ra, dec, time, psi, mode) + + +@dispatch(precedence=1) +def 
rotation_matrix_from_delta(delta: ArrayLike): + return _geometry.rotation_matrix_from_delta(delta) + + +@dispatch(precedence=1) +def time_delay_geocentric(detector1: ArrayLike, detector2: ArrayLike, ra, dec, time): + return _geometry.time_delay_geocentric(detector1, detector2, ra, dec, time) + + +@dispatch(precedence=1) +def time_delay_from_geocenter(detector1: ArrayLike, ra: Real, dec: Real, time: Real | ArrayLike): + return _geometry.time_delay_from_geocenter(detector1, ra, dec, time) + + +@dispatch(precedence=1) +def zenith_azimuth_to_theta_phi(zenith: Real, azimuth: Real, delta_x: np.ndarray): + theta, phi = _geometry.zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + return theta, phi % (2 * np.pi) diff --git a/bilby/gw/compat/jax.py b/bilby/gw/compat/jax.py new file mode 100644 index 000000000..99277e30a --- /dev/null +++ b/bilby/gw/compat/jax.py @@ -0,0 +1,20 @@ +import jax.numpy as jnp +from jax import Array +from plum import dispatch + +from ..time import ( + LEAP_SECONDS as _LEAP_SECONDS, + n_leap_seconds as _n_leap_seconds, +) + +__all__ = ["n_leap_seconds"] + +LEAP_SECONDS = jnp.array(_LEAP_SECONDS) + + +@dispatch +def n_leap_seconds(date: Array): + """ + Find the number of leap seconds required for the specified date. + """ + return _n_leap_seconds(date, LEAP_SECONDS) diff --git a/bilby/gw/compat/torch.py b/bilby/gw/compat/torch.py new file mode 100644 index 000000000..b3958f347 --- /dev/null +++ b/bilby/gw/compat/torch.py @@ -0,0 +1,19 @@ +import torch +from plum import dispatch + +from ..time import ( + LEAP_SECONDS as _LEAP_SECONDS, + n_leap_seconds as _n_leap_seconds, +) + +__all__ = ["n_leap_seconds"] + +LEAP_SECONDS = torch.tensor(_LEAP_SECONDS) + + +@dispatch +def n_leap_seconds(date: torch.Tensor) -> torch.Tensor: + """ + Find the number of leap seconds required for the specified date. 
+ """ + return _n_leap_seconds(date, LEAP_SECONDS) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 9bd9cab06..96cd02dd1 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -26,6 +26,7 @@ lalsim_SimNeutronStarRadius, lalsim_SimNeutronStarLoveNumberK2) +from ..compat.utils import array_module from ..core.likelihood import MarginalizedLikelihoodReconstructionError from ..core.utils import logger, solar_mass, gravitational_constant, speed_of_light, command_line_args, safe_file_dump from ..core.prior import DeltaFunction @@ -204,9 +205,9 @@ def convert_to_lal_binary_black_hole_parameters(parameters): added_keys: list keys which are added to parameters during function call """ - converted_parameters = parameters.copy() original_keys = list(converted_parameters.keys()) + xp = array_module(parameters) if 'luminosity_distance' not in original_keys: if 'redshift' in converted_parameters.keys(): converted_parameters['luminosity_distance'] = \ @@ -244,7 +245,7 @@ def convert_to_lal_binary_black_hole_parameters(parameters): converted_parameters['a_{}'.format(idx)] = abs( converted_parameters[key]) converted_parameters['cos_tilt_{}'.format(idx)] = \ - np.sign(converted_parameters[key]) + xp.sign(xp.asarray(converted_parameters[key])) else: with np.errstate(invalid="raise"): try: @@ -267,13 +268,13 @@ def convert_to_lal_binary_black_hole_parameters(parameters): cos_angle = str('cos_' + angle) if cos_angle in converted_parameters.keys(): with np.errstate(invalid="ignore"): - converted_parameters[angle] = np.arccos(converted_parameters[cos_angle]) + converted_parameters[angle] = xp.arccos(converted_parameters[cos_angle]) - if "delta_phase" in original_keys: + if "delta_phase" in converted_parameters: with np.errstate(invalid="ignore"): - converted_parameters["phase"] = np.mod( + converted_parameters["phase"] = xp.mod( converted_parameters["delta_phase"] - - np.sign(np.cos(converted_parameters["theta_jn"])) + - 
xp.sign(xp.cos(converted_parameters["theta_jn"])) * converted_parameters["psi"], 2 * np.pi) added_keys = [key for key in converted_parameters.keys() @@ -378,19 +379,19 @@ def convert_to_lal_binary_neutron_star_parameters(parameters): g3pca = converted_parameters['eos_spectral_pca_gamma_3'] m1s = converted_parameters['mass_1_source'] m2s = converted_parameters['mass_2_source'] - all_lambda_1 = np.empty(0) - all_lambda_2 = np.empty(0) - all_eos_check = np.empty(0, dtype=bool) + all_lambda_1 = list() + all_lambda_2 = list() + all_eos_check = list() for (g_0pca, g_1pca, g_2pca, g_3pca, m1_s, m2_s) in zip(g0pca, g1pca, g2pca, g3pca, m1s, m2s): g_0, g_1, g_2, g_3 = spectral_pca_to_spectral(g_0pca, g_1pca, g_2pca, g_3pca) lambda_1, lambda_2, eos_check = \ spectral_params_to_lambda_1_lambda_2(g_0, g_1, g_2, g_3, m1_s, m2_s) - all_lambda_1 = np.append(all_lambda_1, lambda_1) - all_lambda_2 = np.append(all_lambda_2, lambda_2) - all_eos_check = np.append(all_eos_check, eos_check) - converted_parameters['lambda_1'] = all_lambda_1 - converted_parameters['lambda_2'] = all_lambda_2 - converted_parameters['eos_check'] = all_eos_check + all_lambda_1.append(lambda_1) + all_lambda_2.append(lambda_2) + all_eos_check.append(eos_check) + converted_parameters['lambda_1'] = np.array(all_lambda_1) + converted_parameters['lambda_2'] = np.array(all_lambda_2) + converted_parameters['eos_check'] = np.array(all_eos_check) for key in float_eos_params.keys(): converted_parameters[key] = float_eos_params[key] elif 'eos_polytrope_gamma_0' and 'eos_polytrope_log10_pressure_1' in converted_parameters.keys(): @@ -630,8 +631,9 @@ def spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3) array of gamma_0, gamma_1, gamma_2, gamma_3 in model space ''' - sampled_pca_gammas = np.array([gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3]) - transformation_matrix = np.array( + xp = array_module(gamma_pca_0) + sampled_pca_gammas = xp.asarray([gamma_pca_0, gamma_pca_1, gamma_pca_2, 
gamma_pca_3]) + transformation_matrix = xp.asarray( [ [0.43801, -0.76705, 0.45143, 0.12646], [-0.53573, 0.17169, 0.67968, 0.47070], @@ -640,10 +642,10 @@ def spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3) ] ) - model_space_mean = np.array([0.89421, 0.33878, -0.07894, 0.00393]) - model_space_standard_deviation = np.array([0.35700, 0.25769, 0.05452, 0.00312]) + model_space_mean = xp.asarray([0.89421, 0.33878, -0.07894, 0.00393]) + model_space_standard_deviation = xp.asarray([0.35700, 0.25769, 0.05452, 0.00312]) converted_gamma_parameters = \ - model_space_mean + model_space_standard_deviation * np.dot(transformation_matrix, sampled_pca_gammas) + model_space_mean + model_space_standard_deviation * (transformation_matrix @ sampled_pca_gammas) return converted_gamma_parameters @@ -958,9 +960,9 @@ def chirp_mass_and_primary_mass_to_mass_ratio(chirp_mass, mass_1): Mass ratio (mass_2/mass_1) of the binary """ a = (chirp_mass / mass_1) ** 5 - t0 = np.cbrt(9 * a + np.sqrt(3) * np.sqrt(27 * a ** 2 - 4 * a ** 3)) - t1 = np.cbrt(2) * 3 ** (2 / 3) - t2 = np.cbrt(2 / 3) * a + t0 = (9 * a + 3**0.5 * (27 * a ** 2 - 4 * a ** 3)**0.5)**(1 / 3) + t1 = (2)**(1 / 3) * 3 ** (2 / 3) + t2 = (2 / 3)**(1 / 3) * a return t2 / t0 + t0 / t1 @@ -1043,8 +1045,8 @@ def component_masses_to_symmetric_mass_ratio(mass_1, mass_2): symmetric_mass_ratio: float Symmetric mass ratio of the binary """ - - return np.minimum((mass_1 * mass_2) / (mass_1 + mass_2) ** 2, 1 / 4) + xp = array_module(mass_1) + return xp.minimum((mass_1 * mass_2) / (mass_1 + mass_2) ** 2, xp.asarray(0.25)) def component_masses_to_mass_ratio(mass_1, mass_2): @@ -1403,17 +1405,17 @@ def binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_s lambda_antisymmetric: float Antisymmetric tidal parameter. """ - lambda_symmetric_m1o5 = np.power(lambda_symmetric, -1. / 5.) 
+ lambda_symmetric_m1o5 = lambda_symmetric ** (-1 / 5) lambda_symmetric_m2o5 = lambda_symmetric_m1o5 * lambda_symmetric_m1o5 lambda_symmetric_m3o5 = lambda_symmetric_m2o5 * lambda_symmetric_m1o5 q = mass_ratio - q2 = np.square(mass_ratio) + q2 = mass_ratio ** 2 # Eqn.2 from CHZ, incorporating the dependence on mass ratio n_polytropic = 0.743 # average polytropic index for the EoSs included in the fit - q_for_Fnofq = np.power(q, 10. / (3. - n_polytropic)) + q_for_Fnofq = q ** (10. / (3. - n_polytropic)) Fnofq = (1. - q_for_Fnofq) / (1. + q_for_Fnofq) # b_ij and c_ij coefficients are given in Table I of CHZ @@ -1483,10 +1485,10 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation(bin lambda_antisymmetric_fitOnly = binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_symmetric, mass_ratio) - lambda_symmetric_sqrt = np.sqrt(lambda_symmetric) + lambda_symmetric_sqrt = lambda_symmetric ** 0.5 q = mass_ratio - q2 = np.square(mass_ratio) + q2 = mass_ratio ** 2 # mu_i and sigma_i coefficients are given in Table II of CHZ @@ -1546,9 +1548,10 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation(bin # Eqn 5 from CHZ, averaging the corrections from the # standard deviations of the residual fits - lambda_antisymmetric_stdCorr = \ - np.sqrt(np.square(lambda_antisymmetric_lambda_symmetric_stdCorr) + - np.square(lambda_antisymmetric_mass_ratio_stdCorr)) + lambda_antisymmetric_stdCorr = ( + lambda_antisymmetric_lambda_symmetric_stdCorr ** 2 + + lambda_antisymmetric_mass_ratio_stdCorr ** 2 + ) ** 0.5 # Draw a correction on the fit from a # Gaussian distribution with width lambda_antisymmetric_stdCorr @@ -2066,28 +2069,29 @@ def generate_spin_parameters(sample): output_sample = sample.copy() output_sample = generate_component_spins(output_sample) + xp = array_module(sample) output_sample['chi_eff'] = (output_sample['spin_1z'] + output_sample['spin_2z'] * output_sample['mass_ratio']) /\ (1 + 
output_sample['mass_ratio']) - output_sample['chi_1_in_plane'] = np.sqrt( + output_sample['chi_1_in_plane'] = ( output_sample['spin_1x'] ** 2 + output_sample['spin_1y'] ** 2 - ) - output_sample['chi_2_in_plane'] = np.sqrt( + ) ** 0.5 + output_sample['chi_2_in_plane'] = ( output_sample['spin_2x'] ** 2 + output_sample['spin_2y'] ** 2 - ) + ) ** 0.5 - output_sample['chi_p'] = np.maximum( + output_sample['chi_p'] = xp.maximum( output_sample['chi_1_in_plane'], (4 * output_sample['mass_ratio'] + 3) / (3 * output_sample['mass_ratio'] + 4) * output_sample['mass_ratio'] * output_sample['chi_2_in_plane']) try: - output_sample['cos_tilt_1'] = np.cos(output_sample['tilt_1']) - output_sample['cos_tilt_2'] = np.cos(output_sample['tilt_2']) + output_sample['cos_tilt_1'] = xp.cos(output_sample['tilt_1']) + output_sample['cos_tilt_2'] = xp.cos(output_sample['tilt_2']) except KeyError: pass @@ -2116,12 +2120,13 @@ def generate_component_spins(sample): ['theta_jn', 'phi_jl', 'tilt_1', 'tilt_2', 'phi_12', 'a_1', 'a_2', 'mass_1', 'mass_2', 'reference_frequency', 'phase'] if all(key in output_sample.keys() for key in spin_conversion_parameters): + xp = array_module(output_sample["theta_jn"]) ( output_sample['iota'], output_sample['spin_1x'], output_sample['spin_1y'], output_sample['spin_1z'], output_sample['spin_2x'], output_sample['spin_2y'], output_sample['spin_2z'] - ) = np.vectorize(bilby_to_lalsimulation_spins)( + ) = xp.vectorize(bilby_to_lalsimulation_spins)( output_sample['theta_jn'], output_sample['phi_jl'], output_sample['tilt_1'], output_sample['tilt_2'], output_sample['phi_12'], output_sample['a_1'], output_sample['a_2'], @@ -2131,10 +2136,10 @@ def generate_component_spins(sample): ) output_sample['phi_1'] =\ - np.fmod(2 * np.pi + np.arctan2( + xp.fmod(2 * np.pi + xp.arctan2( output_sample['spin_1y'], output_sample['spin_1x']), 2 * np.pi) output_sample['phi_2'] =\ - np.fmod(2 * np.pi + np.arctan2( + xp.fmod(2 * np.pi + xp.arctan2( output_sample['spin_2y'], 
output_sample['spin_2x']), 2 * np.pi) elif 'chi_1' in output_sample and 'chi_2' in output_sample: diff --git a/bilby/gw/detector/calibration.py b/bilby/gw/detector/calibration.py index 729b9e332..a4cebffe6 100644 --- a/bilby/gw/detector/calibration.py +++ b/bilby/gw/detector/calibration.py @@ -42,10 +42,13 @@ import copy import os +import array_api_compat as aac import numpy as np import pandas as pd +from array_api_compat import is_jax_namespace from scipy.interpolate import interp1d +from ...compat.utils import array_module, xp_wrap from ...core.utils.log import logger from ...core.prior.dict import PriorDict from ..prior import CalibrationPriorDict @@ -330,9 +333,11 @@ def __repr__(self): def _evaluate_spline(self, kind, a, b, c, d, previous_nodes): """Evaluate Eq. (1) in https://dcc.ligo.org/LIGO-T2300140""" - parameters = np.array([self.params[f"{kind}_{ii}"] for ii in range(self.n_points)]) + xp = array_module(self.params[f"{kind}_0"]) + parameters = xp.asarray([self.params[f"{kind}_{ii}"] for ii in range(self.n_points)]) next_nodes = previous_nodes + 1 - spline_coefficients = self.nodes_to_spline_coefficients.dot(parameters) + nodes = xp.asarray(self.nodes_to_spline_coefficients) + spline_coefficients = nodes.dot(parameters) return ( a * parameters[previous_nodes] + b * parameters[next_nodes] @@ -340,7 +345,8 @@ def _evaluate_spline(self, kind, a, b, c, d, previous_nodes): + d * spline_coefficients[next_nodes] ) - def get_calibration_factor(self, frequency_array, **params): + @xp_wrap + def get_calibration_factor(self, frequency_array, *, xp=np, **params): """Apply calibration model Parameters @@ -358,10 +364,11 @@ def get_calibration_factor(self, frequency_array, **params): calibration_factor : array-like The factor to multiply the strain by. 
""" - log10f_per_deltalog10f = ( - np.log10(frequency_array) - self.log_spline_points[0] + log10f_per_deltalog10f = xp.nan_to_num( + xp.log10(frequency_array) - xp.asarray(self.log_spline_points[0]), + neginf=0.0, ) / self.delta_log_spline_points - previous_nodes = np.clip(np.floor(log10f_per_deltalog10f).astype(int), a_min=0, a_max=self.n_points - 2) + previous_nodes = xp.clip(xp.astype(log10f_per_deltalog10f, int), min=0, max=self.n_points - 2) b = log10f_per_deltalog10f - previous_nodes a = 1 - b c = (a**3 - a) / 6 @@ -373,7 +380,7 @@ def get_calibration_factor(self, frequency_array, **params): delta_phase = self._evaluate_spline("phase", a, b, c, d, previous_nodes) calibration_factor = (1 + delta_amplitude) * (2 + 1j * delta_phase) / (2 - 1j * delta_phase) - return calibration_factor + return xp.nan_to_num(calibration_factor) class Precomputed(Recalibrate): @@ -405,8 +412,21 @@ def get_calibration_factor(self, frequency_array, **params): idx = int(params.get(self.prefix, None)) if idx is None: raise KeyError(f"Calibration index for {self.label} not found.") - if not np.array_equal(frequency_array, self.frequency_array): - raise ValueError("Frequency grid passed to calibrator doesn't match.") + + xp = aac.get_namespace(frequency_array) + if not xp.array_equal(frequency_array, self.frequency_array): + intersection, mask, _ = xp.intersect1d( + frequency_array, self.frequency_array, return_indices=True + ) + if len(intersection) != len(self.frequency_array): + raise ValueError("Frequency grid passed to calibrator doesn't match.") + output = xp.ones_like(frequency_array, dtype=complex) + curve = xp.asarray(self.curves[idx]) + if is_jax_namespace(xp): + output = output.at[mask].set(curve) + else: + output[mask] = curve + return output return self.curves[idx] @classmethod diff --git a/bilby/gw/detector/geometry.py b/bilby/gw/detector/geometry.py index d7e1433de..a6c2df168 100644 --- a/bilby/gw/detector/geometry.py +++ b/bilby/gw/detector/geometry.py @@ -1,5 +1,5 @@ 
import numpy as np -from bilby_cython.geometry import calculate_arm, detector_tensor +from ..geometry import calculate_arm, detector_tensor from .. import utils as gwutils @@ -264,7 +264,7 @@ def detector_tensor(self): if not self._x_updated or not self._y_updated: _, _ = self.x, self.y # noqa if not self._detector_tensor_updated: - self._detector_tensor = detector_tensor(x=self.x, y=self.y) + self._detector_tensor = detector_tensor(self.x, self.y) self._detector_tensor_updated = True return self._detector_tensor @@ -290,17 +290,27 @@ def unit_vector_along_arm(self, arm): """ if arm == 'x': return calculate_arm( - arm_tilt=self._xarm_tilt, - arm_azimuth=self._xarm_azimuth, - longitude=self._longitude, - latitude=self._latitude + self._xarm_tilt, + self._xarm_azimuth, + self._longitude, + self._latitude ) elif arm == 'y': return calculate_arm( - arm_tilt=self._yarm_tilt, - arm_azimuth=self._yarm_azimuth, - longitude=self._longitude, - latitude=self._latitude + self._yarm_tilt, + self._yarm_azimuth, + self._longitude, + self._latitude ) else: raise ValueError("Arm must either be 'x' or 'y'.") + + def set_array_backend(self, xp): + self.length = xp.asarray(self.length) + self.latitude = xp.asarray(self.latitude) + self.longitude = xp.asarray(self.longitude) + self.elevation = xp.asarray(self.elevation) + self.xarm_azimuth = xp.asarray(self.xarm_azimuth) + self.yarm_azimuth = xp.asarray(self.yarm_azimuth) + self.xarm_tilt = xp.asarray(self.xarm_tilt) + self.yarm_tilt = xp.asarray(self.yarm_tilt) diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 94267d30c..bf1543f0b 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -1,16 +1,17 @@ import os import numpy as np -from bilby_cython.geometry import ( - get_polarization_tensor, - three_by_three_matrix_contraction, - time_delay_from_geocenter, -) from ...core import utils -from ...core.utils import docstring, logger, PropertyAccessor, 
safe_file_dump +from ...core.utils import PropertyAccessor, docstring, logger, safe_file_dump from ...core.utils.env import string_to_boolean +from ...compat.utils import array_module from .. import utils as gwutils +from ..geometry import ( + get_polarization_tensor, + three_by_three_matrix_contraction, + time_delay_from_geocenter, +) from .calibration import Recalibrate from .geometry import InterferometerGeometry from .strain_data import InterferometerStrainData @@ -114,16 +115,19 @@ def __repr__(self): float(self.geometry.yarm_azimuth), float(self.geometry.xarm_tilt), float(self.geometry.yarm_tilt)) - def set_strain_data_from_gwpy_timeseries(self, time_series): + def set_strain_data_from_gwpy_timeseries(self, time_series, *, xp=None): """ Set the `Interferometer.strain_data` from a gwpy TimeSeries Parameters ========== time_series: gwpy.timeseries.timeseries.TimeSeries The data to set. + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. """ - self.strain_data.set_from_gwpy_timeseries(time_series=time_series) + self.strain_data.set_from_gwpy_timeseries(time_series=time_series, xp=xp) def set_strain_data_from_frequency_domain_strain( self, frequency_domain_strain, sampling_frequency=None, @@ -174,7 +178,7 @@ def set_strain_data_from_power_spectral_density( def set_strain_data_from_frame_file( self, frame_file, sampling_frequency, duration, start_time=0, - channel=None, buffer_time=1): + channel=None, buffer_time=1, *, xp=None): """ Set the `Interferometer.strain_data` from a frame file Parameters @@ -192,15 +196,18 @@ def set_strain_data_from_frame_file( buffer_time: float Read in data with `start_time-buffer_time` and `start_time+duration+buffer_time` + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. 
""" self.strain_data.set_from_frame_file( frame_file=frame_file, sampling_frequency=sampling_frequency, duration=duration, start_time=start_time, - channel=channel, buffer_time=buffer_time) + channel=channel, buffer_time=buffer_time, xp=xp) def set_strain_data_from_channel_name( - self, channel, sampling_frequency, duration, start_time=0): + self, channel, sampling_frequency, duration, start_time=0, *, xp=None): """ Set the `Interferometer.strain_data` by fetching from given channel using strain_data.set_from_channel_name() @@ -215,22 +222,28 @@ def set_strain_data_from_channel_name( The data duration (in s) start_time: float The GPS start-time of the data + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. """ self.strain_data.set_from_channel_name( channel=channel, sampling_frequency=sampling_frequency, - duration=duration, start_time=start_time) + duration=duration, start_time=start_time, xp=xp) - def set_strain_data_from_csv(self, filename): + def set_strain_data_from_csv(self, filename, *, xp=None): """ Set the `Interferometer.strain_data` from a csv file Parameters ========== filename: str The path to the file to read in + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. """ - self.strain_data.set_from_csv(filename) + self.strain_data.set_from_csv(filename, xp=xp) def set_strain_data_from_zero_noise( self, sampling_frequency, duration, start_time=0): @@ -312,11 +325,14 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= used to set the time at which the antenna response is evaluated, otherwise the provided :code:`Parameters["geocent_time"]` is used. 
""" + xp = array_module(waveform_polarizations) if frequencies is None: - frequencies = self.frequency_array[self.frequency_mask] + # frequencies = self.frequency_array[self.frequency_mask] + frequencies = self.frequency_array mask = self.frequency_mask else: - mask = np.ones(len(frequencies), dtype=bool) + mask = xp.ones(len(frequencies), dtype=bool) + frequencies = xp.asarray(frequencies) if self.reference_time is None: antenna_time = parameters["geocent_time"] @@ -331,8 +347,8 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= antenna_time, parameters['psi'], mode) - signal[mode] = waveform_polarizations[mode] * det_response - signal_ifo = sum(signal.values()) * mask + signal[mode] = waveform_polarizations[mode] * mask * det_response + signal_ifo = sum(signal.values()) time_shift = self.time_delay_from_geocenter( parameters['ra'], parameters['dec'], parameters['geocent_time']) @@ -342,10 +358,12 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= dt_geocent = parameters['geocent_time'] - self.strain_data.start_time dt = dt_geocent + time_shift - signal_ifo[mask] = signal_ifo[mask] * np.exp(-1j * 2 * np.pi * dt * frequencies) + xp = array_module(signal_ifo) + + signal_ifo = signal_ifo * xp.exp(-1j * 2 * np.pi * dt * frequencies) - signal_ifo[mask] *= self.calibration_model.get_calibration_factor( - frequencies, prefix='recalib_{}_'.format(self.name), **parameters + signal_ifo *= self.calibration_model.get_calibration_factor( + frequencies, prefix=f'recalib_{self.name}_', xp=xp, **parameters ) return signal_ifo @@ -494,7 +512,7 @@ def inject_signal_from_waveform_polarizations(self, parameters, injection_polari self.strain_data.frequency_domain_strain += signal_ifo self.meta_data['optimal_SNR'] = ( - np.sqrt(self.optimal_snr_squared(signal=signal_ifo)).real) + self.optimal_snr_squared(signal=signal_ifo)).real ** 0.5 self.meta_data['matched_filter_SNR'] = ( self.matched_filter_snr(signal=signal_ifo)) 
self.meta_data['parameters'] = parameters @@ -680,7 +698,7 @@ def whiten_frequency_series(self, frequency_series : np.array) -> np.array: frequency_series : np.array The frequency series, whitened by the ASD """ - return frequency_series / (self.amplitude_spectral_density_array * np.sqrt(self.duration / 4)) + return frequency_series / (self.amplitude_spectral_density_array * (self.duration / 4)**0.5) def get_whitened_time_series_from_whitened_frequency_series( self, @@ -711,14 +729,13 @@ def get_whitened_time_series_from_whitened_frequency_series( w = \\sqrt{N W} = \\sqrt{\\sum_{k=0}^N \\Theta(f_{max} - f_k)\\Theta(f_k - f_{min})} """ - frequency_window_factor = ( - np.sum(self.frequency_mask) - / len(self.frequency_mask) - ) + xp = array_module(whitened_frequency_series) + + frequency_window_factor = self.frequency_mask.mean() whitened_time_series = ( - np.fft.irfft(whitened_frequency_series) - * np.sqrt(np.sum(self.frequency_mask)) / frequency_window_factor + xp.fft.irfft(whitened_frequency_series) + * self.frequency_mask.sum()**0.5 / frequency_window_factor ) return whitened_time_series @@ -936,3 +953,10 @@ def from_pickle(cls, filename=None): if res.__class__ != cls: raise TypeError('The loaded object is not an Interferometer') return res + + def set_array_backend(self, xp): + self.geometry.set_array_backend(xp=xp) + + @property + def array_backend(self): + return array_module(self.geometry.length) diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index 25b3e7e71..4ac52454a 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -5,6 +5,7 @@ from ...core import utils from ...core.utils import logger, safe_file_dump +from ..geometry import zenith_azimuth_to_theta_phi from .interferometer import Interferometer from .psd import PowerSpectralDensity @@ -341,6 +342,14 @@ def from_pickle(cls, filename=None): ) from_pickle.__doc__ = _load_docstring.format(format="pickle") + def set_array_backend(self, xp): + for ifo 
in self: + ifo.set_array_backend(xp) + + @property + def array_backend(self): + return self[0].array_backend + + class TriangularInterferometer(InterferometerList): def __init__( @@ -472,3 +481,26 @@ def load_interferometer(filename): "{} could not be loaded. Invalid parameter 'shape'.".format(filename) ) return ifo + + +@zenith_azimuth_to_theta_phi.dispatch +def zenith_azimuth_to_theta_phi(zenith, azimuth, ifos: InterferometerList | list): + """ + Convert from the 'detector frame' to the Earth frame. + + Parameters + ========== + zenith: float + The zenith angle in the detector frame + azimuth: float + The azimuthal angle in the detector frame + ifos: list + List of Interferometer objects defining the detector frame + + Returns + ======= + theta, phi: float + The zenith and azimuthal angles in the earth frame. + """ + delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex + return zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) diff --git a/bilby/gw/detector/psd.py b/bilby/gw/detector/psd.py index a3948f966..e3fe7091a 100644 --- a/bilby/gw/detector/psd.py +++ b/bilby/gw/detector/psd.py @@ -3,6 +3,7 @@ import numpy as np from scipy.interpolate import interp1d +from ...compat.utils import xp_wrap from ...core import utils from ...core.utils import logger from .strain_data import InterferometerStrainData @@ -341,7 +342,8 @@ def __import_power_spectral_density(self): """ Automagically load a power spectral density curve """ self.frequency_array, self.psd_array = np.genfromtxt(self.psd_file).T - def get_noise_realisation(self, sampling_frequency, duration): + @xp_wrap + def get_noise_realisation(self, sampling_frequency, duration, *, xp=None): """ Generate frequency Gaussian noise scaled to the power spectral density. 
@@ -363,4 +365,4 @@ def get_noise_realisation(self, sampling_frequency, duration): frequency_domain_strain = self.__power_spectral_density_interpolated(frequencies) ** 0.5 * white_noise out_of_bounds = (frequencies < min(self.frequency_array)) | (frequencies > max(self.frequency_array)) frequency_domain_strain[out_of_bounds] = 0 * (1 + 1j) - return frequency_domain_strain, frequencies + return xp.asarray(frequency_domain_strain), xp.asarray(frequencies) diff --git a/bilby/gw/detector/strain_data.py b/bilby/gw/detector/strain_data.py index bca7acced..017d2ea50 100644 --- a/bilby/gw/detector/strain_data.py +++ b/bilby/gw/detector/strain_data.py @@ -1,5 +1,7 @@ +import array_api_compat as aac import numpy as np +from ...compat.utils import array_module from ...core import utils from ...core.series import CoupledTimeAndFrequencySeries from ...core.utils import logger, PropertyAccessor @@ -498,7 +500,7 @@ def set_from_time_domain_strain( else: raise ValueError("Data times do not match time array") - def set_from_gwpy_timeseries(self, time_series): + def set_from_gwpy_timeseries(self, time_series, *, xp=np): """ Set the strain data from a gwpy TimeSeries This sets the time_domain_strain attribute, the frequency_domain_strain @@ -509,17 +511,23 @@ def set_from_gwpy_timeseries(self, time_series): ========== time_series: gwpy.timeseries.timeseries.TimeSeries The data to use + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. 
""" from gwpy.timeseries import TimeSeries logger.debug('Setting data using provided gwpy TimeSeries object') if not isinstance(time_series, TimeSeries): raise ValueError("Input time_series is not a gwpy TimeSeries") + duration = xp.asarray(time_series.duration.value) + sampling_frequency = xp.asarray(time_series.sample_rate.value) + start_time = xp.asarray(time_series.epoch.value) self._times_and_frequencies = \ - CoupledTimeAndFrequencySeries(duration=time_series.duration.value, - sampling_frequency=time_series.sample_rate.value, - start_time=time_series.epoch.value) - self._time_domain_strain = time_series.value + CoupledTimeAndFrequencySeries(duration=duration, + sampling_frequency=sampling_frequency, + start_time=start_time) + self._time_domain_strain = xp.asarray(time_series.value) self._frequency_domain_strain = None self._channel = time_series.channel @@ -529,7 +537,7 @@ def channel(self): def set_from_open_data( self, name, start_time, duration=4, outdir='outdir', cache=True, - **kwargs): + *, xp=None, **kwargs): """ Set the strain data from open LOSC data This sets the time_domain_strain attribute, the frequency_domain_strain @@ -548,30 +556,38 @@ def set_from_open_data( Directory where the psd files are saved cache: bool, optional Whether or not to store/use the acquired data. + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. **kwargs: All keyword arguments are passed to `gwpy.timeseries.TimeSeries.fetch_open_data()`. 
""" - timeseries = gwutils.get_open_strain_data( - name, start_time, start_time + duration, outdir=outdir, cache=cache, + name, float(start_time), float(start_time + duration), outdir=outdir, cache=cache, **kwargs) - self.set_from_gwpy_timeseries(timeseries) + if xp is None: + xp = array_module((duration, start_time)) + + self.set_from_gwpy_timeseries(timeseries, xp=xp) - def set_from_csv(self, filename): + def set_from_csv(self, filename, xp=None): """ Set the strain data from a csv file Parameters ========== filename: str The path to the file to read in + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. """ from gwpy.timeseries import TimeSeries timeseries = TimeSeries.read(filename, format='csv') - self.set_from_gwpy_timeseries(timeseries) + self.set_from_gwpy_timeseries(timeseries, xp=xp) def set_from_frequency_domain_strain( self, frequency_domain_strain, sampling_frequency=None, @@ -661,12 +677,13 @@ def set_from_zero_noise(self, sampling_frequency, duration, start_time=0): sampling_frequency=sampling_frequency, start_time=start_time) logger.debug('Setting zero noise data') - self._frequency_domain_strain = np.zeros_like(self.frequency_array, + xp = aac.get_namespace(self.frequency_array) + self._frequency_domain_strain = xp.zeros_like(self.frequency_array, dtype=complex) def set_from_frame_file( self, frame_file, sampling_frequency, duration, start_time=0, - channel=None, buffer_time=1): + channel=None, buffer_time=1, *, xp=None): """ Set the `frequency_domain_strain` from a frame fiile Parameters @@ -684,6 +701,10 @@ def set_from_frame_file( buffer_time: float Read in data with `start_time-buffer_time` and `start_time+duration+buffer_time` + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified, it will be inferred from the provided duration/ + sampling frequency. 
""" @@ -697,9 +718,12 @@ def set_from_frame_file( buffer_time=buffer_time, channel=channel, resample=sampling_frequency) - self.set_from_gwpy_timeseries(strain) + if xp is None: + xp = aac.get_namespace(self.frequency_array) - def set_from_channel_name(self, channel, duration, start_time, sampling_frequency): + self.set_from_gwpy_timeseries(strain, xp=xp) + + def set_from_channel_name(self, channel, duration, start_time, sampling_frequency, *, xp=None): """ Set the `frequency_domain_strain` by fetching from given channel using gwpy.TimesSeries.get(), which dynamically accesses either frames on disk, or a remote NDS2 server to find and return data. This function @@ -715,6 +739,10 @@ def set_from_channel_name(self, channel, duration, start_time, sampling_frequenc The GPS start-time of the data sampling_frequency: float The sampling frequency (in Hz) + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified, it will be inferred from the provided duration/ + sampling frequency. 
""" from gwpy.timeseries import TimeSeries @@ -730,7 +758,10 @@ def set_from_channel_name(self, channel, duration, start_time, sampling_frequenc strain = TimeSeries.get(channel, start_time, start_time + duration) strain = strain.resample(sampling_frequency) - self.set_from_gwpy_timeseries(strain) + if xp is None: + xp = aac.get_namespace(self.frequency_array) + + self.set_from_gwpy_timeseries(strain, xp=xp) class Notch(object): diff --git a/bilby/gw/geometry.py b/bilby/gw/geometry.py new file mode 100644 index 000000000..68321d4b4 --- /dev/null +++ b/bilby/gw/geometry.py @@ -0,0 +1,195 @@ +from plum import dispatch + +from .time import greenwich_mean_sidereal_time +from ..compat.utils import array_module, promote_to_array + + +__all__ = [ + "antenna_response", + "calculate_arm", + "detector_tensor", + "get_polarization_tensor", + "get_polarization_tensor_multiple_modes", + "rotation_matrix_from_delta", + "three_by_three_matrix_contraction", + "time_delay_geocentric", + "time_delay_from_geocenter", + "zenith_azimuth_to_theta_phi", +] + + +@dispatch +def antenna_response(detector_tensor, ra, dec, time, psi, mode): + """""" + xp = array_module(detector_tensor) + polarization_tensor = get_polarization_tensor(*promote_to_array((ra, dec, time, psi), xp), mode) + return three_by_three_matrix_contraction(detector_tensor, polarization_tensor) + + +@dispatch +def calculate_arm(arm_tilt, arm_azimuth, longitude, latitude): + """""" + xp = array_module(arm_tilt) + e_long = xp.asarray([-xp.sin(longitude), xp.cos(longitude), longitude * 0]) + e_lat = xp.asarray( + [ + -xp.sin(latitude) * xp.cos(longitude), + -xp.sin(latitude) * xp.sin(longitude), + xp.cos(latitude), + ] + ) + e_h = xp.asarray( + [ + xp.cos(latitude) * xp.cos(longitude), + xp.cos(latitude) * xp.sin(longitude), + xp.sin(latitude), + ] + ) + + return ( + xp.cos(arm_tilt) * xp.cos(arm_azimuth) * e_long + + xp.cos(arm_tilt) * xp.sin(arm_azimuth) * e_lat + + xp.sin(arm_tilt) * e_h + ) + + +@dispatch +def 
detector_tensor(x, y): + """""" + xp = array_module(x) + return (xp.outer(x, x) - xp.outer(y, y)) / 2 + + +@dispatch +def get_polarization_tensor(ra, dec, time, psi, mode): + """""" + from functools import partial + + xp = array_module(ra) + + gmst = greenwich_mean_sidereal_time(time) % (2 * xp.pi) + phi = ra - gmst + theta = xp.atleast_1d(xp.pi / 2 - dec).squeeze() + u = xp.asarray( + [ + xp.cos(phi) * xp.cos(theta), + xp.cos(theta) * xp.sin(phi), + -xp.sin(theta) * xp.ones_like(phi), + ] + ) + v = xp.asarray([ + -xp.sin(phi), xp.cos(phi), xp.zeros_like(phi) + ]) * xp.ones_like(theta) + omega = xp.asarray([ + xp.sin(xp.pi - theta) * xp.cos(xp.pi + phi), + xp.sin(xp.pi - theta) * xp.sin(xp.pi + phi), + xp.cos(xp.pi - theta) * xp.ones_like(phi), + ]) + m = -u * xp.sin(psi) - v * xp.cos(psi) + n = -u * xp.cos(psi) + v * xp.sin(psi) + if xp.__name__ == "mlx.core": + einsum_shape = "i,j->ij" + else: + einsum_shape = "i...,j...->ij..." + product = partial(xp.einsum, einsum_shape) + + match mode.lower(): + case "plus": + return product(m, m) - product(n, n) + case "cross": + return product(m, n) + product(n, m) + case "breathing": + return product(m, m) + product(n, n) + case "longitudinal": + return product(omega, omega) + case "x": + return product(m, omega) + product(omega, m) + case "y": + return product(n, omega) + product(omega, n) + case _: + raise ValueError(f"{mode} not a polarization mode!") + + +@dispatch +def get_polarization_tensor_multiple_modes(ra, dec, time, psi, modes): + """""" + return [get_polarization_tensor(ra, dec, time, psi, mode) for mode in modes] + + +@dispatch +def rotation_matrix_from_delta(delta_x): + """""" + xp = array_module(delta_x) + delta_x = delta_x / (delta_x**2).sum() ** 0.5 + alpha = xp.arctan2(-delta_x[1] * delta_x[2], delta_x[0]) + beta = xp.arccos(delta_x[2]) + gamma = xp.arctan2(delta_x[1], delta_x[0]) + rotation_1 = xp.asarray( + [ + [xp.cos(alpha), -xp.sin(alpha), xp.zeros(alpha.shape)], + [xp.sin(alpha), xp.cos(alpha), 
xp.zeros(alpha.shape)], + [xp.zeros(alpha.shape), xp.zeros(alpha.shape), xp.ones(alpha.shape)], + ] + ) + rotation_2 = xp.asarray( + [ + [xp.cos(beta), xp.zeros(beta.shape), xp.sin(beta)], + [xp.zeros(beta.shape), xp.ones(beta.shape), xp.zeros(beta.shape)], + [-xp.sin(beta), xp.zeros(beta.shape), xp.cos(beta)], + ] + ) + rotation_3 = xp.asarray( + [ + [xp.cos(gamma), -xp.sin(gamma), xp.zeros(gamma.shape)], + [xp.sin(gamma), xp.cos(gamma), xp.zeros(gamma.shape)], + [xp.zeros(gamma.shape), xp.zeros(gamma.shape), xp.ones(gamma.shape)], + ] + ) + return rotation_3 @ rotation_2 @ rotation_1 + + +@dispatch +def three_by_three_matrix_contraction(a, b): + """""" + xp = array_module(a) + return xp.einsum("ij,ij->", a, b) + + +@dispatch +def time_delay_geocentric(detector1, detector2, ra, dec, time): + """""" + xp = array_module(detector1) + gmst = greenwich_mean_sidereal_time(time) % (2 * xp.pi) + speed_of_light = 299792458.0 + phi = ra - gmst + theta = xp.pi / 2 - dec + omega = xp.asarray( + [xp.sin(theta) * xp.cos(phi), xp.sin(theta) * xp.sin(phi), xp.cos(theta)] + ) + delta_d = detector2 - detector1 + return omega @ delta_d / speed_of_light + + +@dispatch +def time_delay_from_geocenter(detector1, ra, dec, time): + """""" + xp = array_module(detector1) + return time_delay_geocentric(detector1, xp.zeros(3), ra, dec, time) + + +@dispatch +def zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x): + """""" + xp = array_module(delta_x) + omega_prime = xp.stack( + [ + xp.sin(zenith) * xp.cos(azimuth), + xp.sin(zenith) * xp.sin(azimuth), + xp.cos(zenith), + ] + ) + rotation_matrix = rotation_matrix_from_delta(delta_x) + omega = rotation_matrix @ omega_prime + theta = xp.arccos(omega[2]) + phi = xp.arctan2(omega[1], omega[0]) % (2 * xp.pi) + return theta, phi diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index e1778ddb1..ab9c36854 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -2,6 +2,7 @@ import os import copy +import 
array_api_compat as aac import attr import numpy as np from scipy.special import logsumexp @@ -107,9 +108,13 @@ class GravitationalWaveTransient(Likelihood): @attr.s(slots=True, weakref_slot=False) class _CalculatedSNRs: - d_inner_h = attr.ib(default=0j, converter=complex) - optimal_snr_squared = attr.ib(default=0, converter=float) - complex_matched_filter_snr = attr.ib(default=0j, converter=complex) + # the complex converter breaks JAX compilation + # d_inner_h = attr.ib(default=0j, converter=complex) + # optimal_snr_squared = attr.ib(default=0, converter=float) + # complex_matched_filter_snr = attr.ib(default=0j, converter=complex) + d_inner_h = attr.ib(default=0j) + optimal_snr_squared = attr.ib(default=0) + complex_matched_filter_snr = attr.ib(default=0j) d_inner_h_array = attr.ib(default=None) optimal_snr_squared_array = attr.ib(default=None) @@ -153,6 +158,7 @@ def __init__( self.waveform_generator = waveform_generator super(GravitationalWaveTransient, self).__init__() self.interferometers = InterferometerList(interferometers) + self.interferometers.set_array_backend(interferometers.array_backend) self.time_marginalization = time_marginalization self.distance_marginalization = distance_marginalization self.phase_marginalization = phase_marginalization @@ -165,6 +171,7 @@ def __init__( if "geocent" not in time_reference: self.time_reference = time_reference self.reference_ifo = get_empty_interferometer(self.time_reference) + self.reference_ifo.set_array_backend(self.interferometers.array_backend) if self.time_marginalization: logger.info("Cannot marginalise over non-geocenter time.") self.time_marginalization = False @@ -291,49 +298,50 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr optimal_snr_squared_array = None normalization = 4 / self.waveform_generator.duration + xp = aac.array_namespace(signal) if return_array is False: d_inner_h_array = None optimal_snr_squared_array = None elif self.time_marginalization and 
self.calibration_marginalization: - d_inner_h_integrand = np.tile( - interferometer.frequency_domain_strain.conjugate() * signal / + d_inner_h_integrand = xp.tile( + interferometer.frequency_domain_strain.conj() * signal / interferometer.power_spectral_density_array, (self.number_of_response_curves, 1)).T d_inner_h_integrand[_mask] *= self.calibration_draws[interferometer.name].T - d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft( + d_inner_h_array = 4 / self.waveform_generator.duration * xp.fft.fft( d_inner_h_integrand[0:-1], axis=0 ).T optimal_snr_squared_integrand = ( - normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array + normalization * xp.abs(signal)**2 / interferometer.power_spectral_density_array ) - optimal_snr_squared_array = np.dot( + optimal_snr_squared_array = xp.dot( optimal_snr_squared_integrand[_mask], self.calibration_abs_draws[interferometer.name].T ) elif self.time_marginalization and not self.calibration_marginalization: - d_inner_h_array = normalization * np.fft.fft( + d_inner_h_array = normalization * xp.fft.fft( signal[0:-1] - * interferometer.frequency_domain_strain.conjugate()[0:-1] + * interferometer.frequency_domain_strain.conj()[0:-1] / interferometer.power_spectral_density_array[0:-1] ) elif self.calibration_marginalization and ('recalib_index' not in parameters): d_inner_h_integrand = ( normalization * - interferometer.frequency_domain_strain.conjugate() * signal + interferometer.frequency_domain_strain.conj() * signal / interferometer.power_spectral_density_array ) - d_inner_h_array = np.dot(d_inner_h_integrand[_mask], self.calibration_draws[interferometer.name].T) + d_inner_h_array = xp.dot(d_inner_h_integrand[_mask], self.calibration_draws[interferometer.name].T) optimal_snr_squared_integrand = ( - normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array + normalization * xp.abs(signal)**2 / interferometer.power_spectral_density_array ) - optimal_snr_squared_array 
= np.dot( + optimal_snr_squared_array = xp.dot( optimal_snr_squared_integrand[_mask], self.calibration_abs_draws[interferometer.name].T ) @@ -392,12 +400,12 @@ def _calculate_noise_log_likelihood(self): log_l = 0 for interferometer in self.interferometers: mask = interferometer.frequency_mask - log_l -= noise_weighted_inner_product( + log_l -= abs(noise_weighted_inner_product( interferometer.frequency_domain_strain[mask], interferometer.frequency_domain_strain[mask], interferometer.power_spectral_density_array[mask], - self.waveform_generator.duration) / 2 - return float(np.real(log_l)) + self.waveform_generator.duration) / 2) + return log_l def noise_log_likelihood(self): # only compute likelihood if called for the 1st time @@ -410,6 +418,7 @@ def log_likelihood_ratio(self, parameters=None): parameters = copy.deepcopy(parameters) else: parameters = _fallback_to_parameters(self, parameters) + parameters.update(self.get_sky_frame_parameters(parameters)) waveform_polarizations = \ self.waveform_generator.frequency_domain_strain(parameters) if waveform_polarizations is None: @@ -418,8 +427,6 @@ def log_likelihood_ratio(self, parameters=None): if self.time_marginalization and self.jitter_time: parameters['geocent_time'] += parameters['time_jitter'] - parameters.update(self.get_sky_frame_parameters(parameters)) - total_snrs = self._CalculatedSNRs() for interferometer in self.interferometers: @@ -436,7 +443,7 @@ def log_likelihood_ratio(self, parameters=None): if self.time_marginalization and self.jitter_time: parameters['geocent_time'] -= parameters['time_jitter'] - return float(log_l.real) + return log_l.real def compute_log_likelihood_from_snrs(self, total_snrs, parameters=None): parameters = _fallback_to_parameters(self, parameters) @@ -472,14 +479,13 @@ def compute_log_likelihood_from_snrs(self, total_snrs, parameters=None): def compute_per_detector_log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) + 
parameters.update(self.get_sky_frame_parameters(parameters)) waveform_polarizations = \ self.waveform_generator.frequency_domain_strain(parameters) if self.time_marginalization and self.jitter_time: parameters['geocent_time'] += parameters['time_jitter'] - parameters.update(self.get_sky_frame_parameters(parameters)) - for interferometer in self.interferometers: per_detector_snr = self.calculate_snrs( waveform_polarizations=waveform_polarizations, @@ -779,12 +785,12 @@ def distance_marginalized_likelihood(self, d_inner_h, h_inner_h, parameters=None d_inner_h_ref, h_inner_h_ref = self._setup_rho( d_inner_h, h_inner_h, parameters=parameters) if self.phase_marginalization: - d_inner_h_ref = np.abs(d_inner_h_ref) + d_inner_h_ref = abs(d_inner_h_ref) else: - d_inner_h_ref = np.real(d_inner_h_ref) + d_inner_h_ref = d_inner_h_ref.real return self._interp_dist_margd_loglikelihood( - d_inner_h_ref, h_inner_h_ref, grid=False) + d_inner_h_ref, h_inner_h_ref) def phase_marginalized_likelihood(self, d_inner_h, h_inner_h): d_inner_h = ln_i0(abs(d_inner_h)) @@ -800,14 +806,15 @@ def time_marginalized_likelihood(self, d_inner_h_tc_array, h_inner_h, parameters if self.jitter_time: times = self._times + parameters['time_jitter'] - _time_prior = self.priors['geocent_time'] - time_mask = (times >= _time_prior.minimum) & (times <= _time_prior.maximum) - times = times[time_mask] + if not aac.is_jax_array(d_inner_h_tc_array): + _time_prior = self.priors['geocent_time'] + time_mask = (times >= _time_prior.minimum) & (times <= _time_prior.maximum) + times = times[time_mask] + if self.calibration_marginalization: + d_inner_h_tc_array = d_inner_h_tc_array[:, time_mask] + else: + d_inner_h_tc_array = d_inner_h_tc_array[time_mask] time_prior_array = self.priors['geocent_time'].prob(times) * self._delta_tc - if self.calibration_marginalization: - d_inner_h_tc_array = d_inner_h_tc_array[:, time_mask] - else: - d_inner_h_tc_array = d_inner_h_tc_array[time_mask] if self.distance_marginalization: 
log_l_tc_array = self.distance_marginalized_likelihood( @@ -817,9 +824,9 @@ def time_marginalized_likelihood(self, d_inner_h_tc_array, h_inner_h, parameters d_inner_h=d_inner_h_tc_array, h_inner_h=h_inner_h) elif self.calibration_marginalization: - log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h[:, np.newaxis] / 2 + log_l_tc_array = d_inner_h_tc_array.real - h_inner_h[:, np.newaxis] / 2 else: - log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h / 2 + log_l_tc_array = d_inner_h_tc_array.real - h_inner_h / 2 return logsumexp(log_l_tc_array, b=time_prior_array, axis=-1) def get_calibration_log_likelihoods(self, signal_polarizations=None, parameters=None): @@ -933,8 +940,11 @@ def _setup_distance_marginalization(self, lookup_table=None): else: self._create_lookup_table() self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline( - self._d_inner_h_ref_array, self._optimal_snr_squared_ref_array, - self._dist_margd_loglikelihood_array.T, fill_value=-np.inf) + self._d_inner_h_ref_array, + self._optimal_snr_squared_ref_array, + self._dist_margd_loglikelihood_array.T, + fill_value=-np.inf, + ) @property def cached_lookup_table_filename(self): @@ -1088,6 +1098,8 @@ def reference_frame(self, frame): self._reference_frame = InterferometerList([frame[:2], frame[2:4]]) else: raise ValueError("Unable to parse reference frame {}".format(frame)) + if isinstance(self._reference_frame, InterferometerList): + self._reference_frame.set_array_backend(self.interferometers.array_backend) def get_sky_frame_parameters(self, parameters=None): """ diff --git a/bilby/gw/likelihood/basic.py b/bilby/gw/likelihood/basic.py index 2931ec742..1a4c495a6 100644 --- a/bilby/gw/likelihood/basic.py +++ b/bilby/gw/likelihood/basic.py @@ -43,10 +43,11 @@ def noise_log_likelihood(self): """ log_l = 0 for interferometer in self.interferometers: - log_l -= 2. 
/ self.waveform_generator.duration * np.sum( - abs(interferometer.frequency_domain_strain) ** 2 / - interferometer.power_spectral_density_array) - return log_l.real + log_l -= 2. / self.waveform_generator.duration * ( + abs(interferometer.frequency_domain_strain) ** 2 + / interferometer.power_spectral_density_array + ).sum() + return log_l def log_likelihood(self, parameters=None): """ Calculates the real part of log-likelihood value @@ -87,8 +88,9 @@ def log_likelihood_interferometer(self, waveform_polarizations, signal_ifo = interferometer.get_detector_response( waveform_polarizations, parameters) - log_l = - 2. / self.waveform_generator.duration * np.vdot( - interferometer.frequency_domain_strain - signal_ifo, - (interferometer.frequency_domain_strain - signal_ifo) / - interferometer.power_spectral_density_array) + residual = interferometer.frequency_domain_strain - signal_ifo + + log_l = - 2. / self.waveform_generator.duration * ( + abs(residual)**2 / interferometer.power_spectral_density_array + ).sum() return log_l.real diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index 056864334..d09e186fb 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -5,6 +5,7 @@ import numpy as np from .base import GravitationalWaveTransient +from ...compat.utils import array_module from ...core.utils import ( logger, speed_of_light, solar_mass, radius_of_earth, gravitational_constant, round_up_to_power_of_two, @@ -533,8 +534,10 @@ def _setup_linear_coefficients(self): for ifo in self.interferometers: logger.info("Pre-computing linear coefficients for {}".format(ifo.name)) fddata = np.zeros(N // 2 + 1, dtype=complex) - fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask[:len(fddata)]] += \ + fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask[:len(fddata)]] += np.asarray( ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] + ) + for b in 
range(self.number_of_bands): Ks, Ke = self.Ks_Ke[b] windows = self._get_window_sequence(1. / self.durations[b], Ks, Ke - Ks + 1, b) @@ -551,7 +554,7 @@ def _setup_quadratic_coefficients_linear_interp(self): linear-interpolation algorithm""" logger.info("Linear-interpolation algorithm is used for (h, h).") self.quadratic_coeffs = dict((ifo.name, np.array([])) for ifo in self.interferometers) - original_duration = self.interferometers.duration + original_duration = float(self.interferometers.duration) for b in range(self.number_of_bands): logger.info(f"Pre-computing quadratic coefficients for the {b}-th band") @@ -575,7 +578,7 @@ def _setup_quadratic_coefficients_linear_interp(self): start_idx_in_band + len(window_sequence) - 1, len(ifo.power_spectral_density_array) - 1 ) - _frequency_mask = ifo.frequency_mask[start_idx_in_band:end_idx_in_band + 1] + _frequency_mask = np.asarray(ifo.frequency_mask[start_idx_in_band:end_idx_in_band + 1]) window_over_psd = np.zeros(end_idx_in_band + 1 - start_idx_in_band) window_over_psd[_frequency_mask] = \ 1. 
/ ifo.power_spectral_density_array[start_idx_in_band:end_idx_in_band + 1][_frequency_mask] @@ -710,13 +713,13 @@ def setup_multibanding_from_weights(self, weights): setattr(self, key, value) def _setup_time_marginalization_multiband(self): - """This overwrites attributes set by _setup_time_marginalization of the base likelihood class""" + self._beam_pattern_reference_time = ( + self.priors['geocent_time'].minimum + self.priors['geocent_time'].maximum + ) / 2 N = self.Nbs[-1] // 2 self._delta_tc = self.durations[0] / N - self._times = \ - self.interferometers.start_time + np.arange(N) * self._delta_tc - self.time_prior_array = \ - self.priors['geocent_time'].prob(self._times) * self._delta_tc + self._times = self.interferometers.start_time + np.arange(N) * self._delta_tc + self.time_prior_array = self.priors['geocent_time'].prob(self._times) * self._delta_tc # allocate array which is FFTed at each likelihood evaluation self._full_d_h = np.zeros(N, dtype=complex) # idxs to convert full frequency points to banded frequency points, used for filling _full_d_h. @@ -748,7 +751,6 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr """ parameters = _fallback_to_parameters(self, parameters) - modes = { mode: value[self.unique_to_original_frequencies] for mode, value in waveform_polarizations.items() @@ -757,12 +759,14 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr modes, parameters, frequencies=self.banded_frequency_points ) - d_inner_h = np.conj(np.dot(strain, self.linear_coeffs[interferometer.name])) + d_inner_h = (strain @ self.linear_coeffs[interferometer.name]).conj() + + xp = array_module(strain) if self.linear_interpolation: - optimal_snr_squared = np.vdot( - np.real(strain * np.conjugate(strain)), - self.quadratic_coeffs[interferometer.name] + optimal_snr_squared = xp.vdot( + xp.abs(strain)**2, + xp.asarray(self.quadratic_coeffs[interferometer.name]) ) else: optimal_snr_squared = 0. 
@@ -771,18 +775,21 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr start_idx, end_idx = self.start_end_idxs[b] Mb = self.Mbs[b] if b == 0: - optimal_snr_squared += (4. / self.interferometers.duration) * np.vdot( - np.real(strain[start_idx:end_idx + 1] * np.conjugate(strain[start_idx:end_idx + 1])), - interferometer.frequency_mask[Ks:Ke + 1] * self.windows[start_idx:end_idx + 1] + optimal_snr_squared += (4. / self.interferometers.duration) * xp.vdot( + xp.abs(strain[start_idx:end_idx + 1])**2, + interferometer.frequency_mask[Ks:Ke + 1] * xp.asarray(self.windows[start_idx:end_idx + 1]) / interferometer.power_spectral_density_array[Ks:Ke + 1]) else: self.wths[interferometer.name][b][Ks:Ke + 1] = ( - self.square_root_windows[start_idx:end_idx + 1] * strain[start_idx:end_idx + 1] + xp.asarray(self.square_root_windows[start_idx:end_idx + 1]) + * strain[start_idx:end_idx + 1] + ) + self.hbcs[interferometer.name][b][-Mb:] = xp.fft.irfft( + xp.asarray(self.wths[interferometer.name][b]) ) - self.hbcs[interferometer.name][b][-Mb:] = np.fft.irfft(self.wths[interferometer.name][b]) - thbc = np.fft.rfft(self.hbcs[interferometer.name][b]) - optimal_snr_squared += (4. / self.Tbhats[b]) * np.vdot( - np.real(thbc * np.conjugate(thbc)), self.Ibcs[interferometer.name][b]) + thbc = xp.fft.rfft(xp.asarray(self.hbcs[interferometer.name][b])) + optimal_snr_squared += (4. 
/ self.Tbhats[b]) * xp.vdot( + xp.abs(thbc)**2, xp.asarray(self.Ibcs[interferometer.name][b].real)) complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) @@ -792,7 +800,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr start_idx, end_idx = self.start_end_idxs[b] self._full_d_h[self._full_to_multiband[start_idx:end_idx + 1]] += \ strain[start_idx:end_idx + 1] * self.linear_coeffs[interferometer.name][start_idx:end_idx + 1] - d_inner_h_array = np.fft.fft(self._full_d_h) + d_inner_h_array = xp.fft.fft(self._full_d_h) else: d_inner_h_array = None diff --git a/bilby/gw/likelihood/relative.py b/bilby/gw/likelihood/relative.py index f4c72e8ef..5a7b2a539 100644 --- a/bilby/gw/likelihood/relative.py +++ b/bilby/gw/likelihood/relative.py @@ -4,6 +4,7 @@ from scipy.optimize import differential_evolution from .base import GravitationalWaveTransient +from ...compat.utils import array_module from ...core.utils import logger from ...core.prior.base import Constraint from ...core.prior import DeltaFunction @@ -253,7 +254,7 @@ def set_fiducial_waveforms(self, parameters): for interferometer in self.interferometers: logger.debug(f"Maximum Frequency is {interferometer.maximum_frequency}") wf = interferometer.get_detector_response(self.fiducial_polarizations, parameters) - wf[interferometer.frequency_array > self.maximum_frequency] = 0 + wf *= interferometer.frequency_array <= self.maximum_frequency self.per_detector_fiducial_waveforms[interferometer.name] = wf def find_maximum_likelihood_parameters(self, parameter_bounds, @@ -327,7 +328,7 @@ def compute_summary_data(self): masked_bin_inds[-1] += 1 masked_strain = interferometer.frequency_domain_strain[mask] - masked_h0 = self.per_detector_fiducial_waveforms[interferometer.name][mask] + masked_h0 = np.asarray(self.per_detector_fiducial_waveforms[interferometer.name][mask]) masked_psd = interferometer.power_spectral_density_array[mask] duration = interferometer.duration a0, b0, a1, b1 = 
np.zeros((4, self.number_of_bins), dtype=complex) @@ -397,20 +398,21 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr parameters=parameters, ) a0, a1, b0, b1 = self.summary_data[interferometer.name] - d_inner_h = np.sum(a0 * np.conjugate(r0) + a1 * np.conjugate(r1)) - h_inner_h = np.sum(b0 * np.abs(r0) ** 2 + 2 * b1 * np.real(r0 * np.conjugate(r1))) + d_inner_h = (a0 * r0.conj() + a1 * r1.conj()).sum() + h_inner_h = (b0 * abs(r0) ** 2 + 2 * b1 * (r0 * r1.conj()).real).sum() optimal_snr_squared = h_inner_h complex_matched_filter_snr = d_inner_h / (optimal_snr_squared ** 0.5) if return_array and self.time_marginalization: + xp = array_module(r0) full_waveform = self._compute_full_waveform( signal_polarizations=waveform_polarizations, interferometer=interferometer, parameters=parameters, ) - d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft( + d_inner_h_array = 4 / self.waveform_generator.duration * xp.fft.fft( full_waveform[0:-1] - * interferometer.frequency_domain_strain.conjugate()[0:-1] + * interferometer.frequency_domain_strain.conj()[0:-1] / interferometer.power_spectral_density_array[0:-1]) else: diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index fd3ddae54..f5c02f4a8 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -1,7 +1,9 @@ - +import array_api_compat as aac +import array_api_extra as xpx import numpy as np from .base import GravitationalWaveTransient +from ...compat.utils import array_module from ...core.utils import ( logger, create_frequency_series, speed_of_light, radius_of_earth ) @@ -271,15 +273,16 @@ def _set_unique_frequency_nodes_and_inverse(self): """Set unique frequency nodes and indices to recover linear and quadratic frequency nodes for each combination of linear and quadratic bases """ + xp = aac.array_namespace(self.interferometers.frequency_array) self._unique_frequency_nodes_and_inverse = [] for idx_linear in 
range(self.number_of_bases_linear): tmp = [] - frequency_nodes_linear = self.weights['frequency_nodes_linear'][idx_linear] + frequency_nodes_linear = xp.asarray(self.weights['frequency_nodes_linear'][idx_linear]) size_linear = len(frequency_nodes_linear) for idx_quadratic in range(self.number_of_bases_quadratic): - frequency_nodes_quadratic = self.weights['frequency_nodes_quadratic'][idx_quadratic] - frequency_nodes_unique, original_indices = np.unique( - np.hstack((frequency_nodes_linear, frequency_nodes_quadratic)), + frequency_nodes_quadratic = xp.asarray(self.weights['frequency_nodes_quadratic'][idx_quadratic]) + frequency_nodes_unique, original_indices = xp.unique( + xp.hstack((frequency_nodes_linear, frequency_nodes_quadratic)), return_inverse=True ) linear_indices = original_indices[:size_linear] @@ -456,10 +459,8 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr frequency_nodes = self.waveform_generator.waveform_arguments['frequency_nodes'] linear_indices = self.waveform_generator.waveform_arguments['linear_indices'] quadratic_indices = self.waveform_generator.waveform_arguments['quadratic_indices'] - size_linear = len(linear_indices) - size_quadratic = len(quadratic_indices) - h_linear = np.zeros(size_linear, dtype=complex) - h_quadratic = np.zeros(size_quadratic, dtype=complex) + h_linear = 0j + h_quadratic = 0j for mode in waveform_polarizations['linear']: response = interferometer.antenna_response( parameters['ra'], parameters['dec'], @@ -470,14 +471,15 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr h_linear += waveform_polarizations['linear'][mode] * response h_quadratic += waveform_polarizations['quadratic'][mode] * response + xp = array_module(h_linear) calib_factor = interferometer.calibration_model.get_calibration_factor( - frequency_nodes, prefix='recalib_{}_'.format(interferometer.name), **parameters) + xp.asarray(frequency_nodes), prefix=f'recalib_{interferometer.name}_', 
xp=xp, **parameters) h_linear *= calib_factor[linear_indices] h_quadratic *= calib_factor[quadratic_indices] - optimal_snr_squared = np.vdot( - np.abs(h_quadratic)**2, - self.weights[interferometer.name + '_quadratic'][self.basis_number_quadratic] + optimal_snr_squared = xp.vdot( + xp.abs(h_quadratic)**2, + xp.asarray(self.weights[interferometer.name + '_quadratic'][self.basis_number_quadratic]) ) dt = interferometer.time_delay_from_geocenter( @@ -486,21 +488,25 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr ifo_time = dt_geocent + dt indices, in_bounds = self._closest_time_indices( - ifo_time, self.weights['time_samples']) - if not in_bounds: - logger.debug("SNR calculation error: requested time at edge of ROQ time samples") - d_inner_h = -np.inf - complex_matched_filter_snr = -np.inf - else: - d_inner_h_tc_array = np.einsum( - 'i,ji->j', np.conjugate(h_linear), - self.weights[interferometer.name + '_linear'][self.basis_number_linear][indices]) + ifo_time, xp.asarray(self.weights['time_samples'])) + indices = xp.clip(xp.asarray(indices), 0, len(self.weights['time_samples']) - 1) + d_inner_h_tc_array = xp.einsum( + 'i,ji->j', + xp.conj(h_linear), + xp.asarray( + self.weights[interferometer.name + '_linear'][self.basis_number_linear] + )[indices], + ) + + d_inner_h = self._interp_five_samples( + xp.asarray(self.weights['time_samples'])[indices], d_inner_h_tc_array, ifo_time + ) - d_inner_h = self._interp_five_samples( - self.weights['time_samples'][indices], d_inner_h_tc_array, ifo_time) + with np.errstate(invalid="ignore"): + complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) - with np.errstate(invalid="ignore"): - complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) + d_inner_h += xp.log(in_bounds) + complex_matched_filter_snr += xp.log(in_bounds) if return_array and self.time_marginalization: ifo_times = self._times - interferometer.strain_data.start_time @@ -537,9 +543,10 @@ def 
_closest_time_indices(time, samples): in_bounds: bool Whether the indices are for valid times """ - closest = int((time - samples[0]) / (samples[1] - samples[0])) + xp = array_module(time) + closest = xp.astype(xp.floor((time - samples[0]) / (samples[1] - samples[0])), int) indices = [closest + ii for ii in [-2, -1, 0, 1, 2]] - in_bounds = (indices[0] >= 0) & (indices[-1] < samples.size) + in_bounds = (indices[0] >= 0) & (indices[-1] < len(samples)) return indices, in_bounds @staticmethod @@ -564,13 +571,13 @@ def _interp_five_samples(time_samples, values, time): """ r1 = (-values[0] + 8. * values[1] - 14. * values[2] + 8. * values[3] - values[4]) / 4. r2 = values[2] - 2. * values[3] + values[4] - a = (time_samples[3] - time) / (time_samples[1] - time_samples[0]) + a = (time_samples[3] - time) / max(time_samples[1] - time_samples[0], 1e-12) b = 1. - a c = (a**3. - a) / 6. d = (b**3. - b) / 6. return a * values[2] + b * values[3] + c * r1 + d * r2 - def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): + def _calculate_d_inner_h_array(self, times, h_linear, ifo_name, *, xp=None): """ Calculate d_inner_h at regularly-spaced time samples. Each value is interpolated from the nearest 5 samples with the algorithm explained in @@ -588,21 +595,23 @@ def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): ======= d_inner_h_array: array-like """ + if xp is None: + xp = aac.array_namespace(h_linear) roq_time_space = self.weights['time_samples'][1] - self.weights['time_samples'][0] times_per_roq_time_space = (times - self.weights['time_samples'][0]) / roq_time_space - closest_idxs = np.floor(times_per_roq_time_space).astype(int) + closest_idxs = xp.astype(xp.floor(times_per_roq_time_space), int) # Get the nearest 5 samples of d_inner_h. Calculate only the required d_inner_h values if the time # spacing is larger than 5 times the ROQ time spacing. 
weights_linear = self.weights[ifo_name + '_linear'][self.basis_number_linear] - h_linear_conj = np.conjugate(h_linear) + h_linear_conj = h_linear.conj() if (times[1] - times[0]) / roq_time_space > 5: - d_inner_h_m2 = np.dot(weights_linear[closest_idxs - 2], h_linear_conj) - d_inner_h_m1 = np.dot(weights_linear[closest_idxs - 1], h_linear_conj) - d_inner_h_0 = np.dot(weights_linear[closest_idxs], h_linear_conj) - d_inner_h_p1 = np.dot(weights_linear[closest_idxs + 1], h_linear_conj) - d_inner_h_p2 = np.dot(weights_linear[closest_idxs + 2], h_linear_conj) + d_inner_h_m2 = weights_linear[closest_idxs - 2] @ h_linear_conj + d_inner_h_m1 = weights_linear[closest_idxs - 1] @ h_linear_conj + d_inner_h_0 = weights_linear[closest_idxs] @ h_linear_conj + d_inner_h_p1 = weights_linear[closest_idxs + 1] @ h_linear_conj + d_inner_h_p2 = weights_linear[closest_idxs + 2] @ h_linear_conj else: - d_inner_h_at_roq_time_samples = np.dot(weights_linear, h_linear_conj) + d_inner_h_at_roq_time_samples = weights_linear @ h_linear_conj d_inner_h_m2 = d_inner_h_at_roq_time_samples[closest_idxs - 2] d_inner_h_m1 = d_inner_h_at_roq_time_samples[closest_idxs - 1] d_inner_h_0 = d_inner_h_at_roq_time_samples[closest_idxs] @@ -654,17 +663,17 @@ def perform_roq_params_check(self, ifo=None): except ValueError: roq_minimum_component_mass = None - if ifo.maximum_frequency > roq_maximum_frequency: + if float(ifo.maximum_frequency) > roq_maximum_frequency: raise BilbyROQParamsRangeError( "Requested maximum frequency {} larger than ROQ basis fhigh {}" .format(ifo.maximum_frequency, roq_maximum_frequency) ) - if ifo.minimum_frequency < roq_minimum_frequency: + if float(ifo.minimum_frequency) < roq_minimum_frequency: raise BilbyROQParamsRangeError( "Requested minimum frequency {} lower than ROQ basis flow {}" .format(ifo.minimum_frequency, roq_minimum_frequency) ) - if ifo.strain_data.duration != roq_segment_length: + if float(ifo.strain_data.duration) != roq_segment_length: raise 
BilbyROQParamsRangeError( "Requested duration differs from ROQ basis seglen") @@ -710,6 +719,7 @@ def _set_weights(self, linear_matrix, quadratic_matrix): linear and quadratic basis """ + xp = aac.array_namespace(self.interferometers.frequency_array) time_space = self._get_time_resolution() number_of_time_samples = int(self.interferometers.duration / time_space) earth_light_crossing_time = 2 * radius_of_earth / speed_of_light + 5 * time_space @@ -729,7 +739,7 @@ def _set_weights(self, linear_matrix, quadratic_matrix): - self.interferometers.start_time ) / time_space)) ) - self.weights['time_samples'] = np.arange(start_idx, end_idx + 1) * time_space + self.weights['time_samples'] = xp.arange(start_idx, end_idx + 1) * float(time_space) logger.info("Using {} ROQ time samples".format(len(self.weights['time_samples']))) # select bases to be used, set prior ranges and frequency nodes if exist @@ -782,10 +792,10 @@ def _set_weights(self, linear_matrix, quadratic_matrix): roq_mask = roq_frequencies >= roq_scaled_minimum_frequency roq_frequencies = roq_frequencies[roq_mask] overlap_frequencies, ifo_idxs_this_ifo, roq_idxs_this_ifo = np.intersect1d( - ifo.frequency_array[ifo.frequency_mask], roq_frequencies, + np.asarray(ifo.frequency_array[ifo.frequency_mask]), roq_frequencies, return_indices=True) else: - overlap_frequencies = ifo.frequency_array[ifo.frequency_mask] + overlap_frequencies = np.asarray(ifo.frequency_array[ifo.frequency_mask]) roq_idxs_this_ifo = np.arange( linear_matrix['basis_linear'][str(idxs_in_prior_range['linear'][0])]['basis'].shape[1], dtype=int) @@ -841,32 +851,44 @@ def _set_weights_linear(self, linear_matrix, basis_idxs, roq_idxs, ifo_idxs): data_over_psd = {} for ifo in self.interferometers: nonzero_idxs[ifo.name] = ifo_idxs[ifo.name] + int( - ifo.frequency_array[ifo.frequency_mask][0] * self.interferometers.duration) - data_over_psd[ifo.name] = ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs[ifo.name]] / \ - 
ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]] - try: - import pyfftw - ifft_input = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) - ifft_output = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) - ifft = pyfftw.FFTW(ifft_input, ifft_output, direction='FFTW_BACKWARD') - except ImportError: + ifo.minimum_frequency * self.interferometers.duration) + data_over_psd[ifo.name] = ( + ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs[ifo.name]] + / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]] + ) + xp = array_module(data_over_psd) + if aac.is_numpy_namespace(xp): + try: + import pyfftw + ifft_input = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) + ifft_output = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) + ifft = pyfftw.FFTW(ifft_input, ifft_output, direction='FFTW_BACKWARD') + except ImportError: + pyfftw = None + logger.warning("You do not have pyfftw installed, falling back to numpy.fft.") + ifft_input = np.zeros(number_of_time_samples, dtype=complex) + ifft = np.fft.ifft + else: pyfftw = None - logger.warning("You do not have pyfftw installed, falling back to numpy.fft.") - ifft_input = np.zeros(number_of_time_samples, dtype=complex) - ifft = np.fft.ifft + ifft_input = xp.zeros(number_of_time_samples, dtype=complex) + ifft = xp.fft.ifft for basis_idx in basis_idxs: logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.") - linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis'] + linear_matrix_single = xp.asarray(linear_matrix['basis_linear'][str(basis_idx)]['basis']) basis_size = linear_matrix_single.shape[0] for ifo in self.interferometers: - ifft_input[:] *= 0. + if pyfftw: + ifft_input[:] *= 0. 
+ else: + ifft_input *= 0 linear_weights = \ - np.zeros((len(self.weights['time_samples']), basis_size), dtype=complex) + xp.zeros((basis_size, len(self.weights['time_samples'])), dtype=complex) for i in range(basis_size): - basis_element = linear_matrix_single[i][roq_idxs[ifo.name]] - ifft_input[nonzero_idxs[ifo.name]] = data_over_psd[ifo.name] * np.conj(basis_element) - linear_weights[:, i] = ifft(ifft_input)[start_idx:end_idx + 1] - linear_weights *= 4. * number_of_time_samples / self.interferometers.duration + basis_element = xp.asarray(linear_matrix_single[i][roq_idxs[ifo.name]]).conj() + ifft_input = xpx.at(ifft_input, nonzero_idxs[ifo.name]).set(data_over_psd[ifo.name] * basis_element) + linear_weights = xpx.at(linear_weights, i).set(ifft(ifft_input)[start_idx:end_idx + 1]) + linear_weights = linear_weights.T + linear_weights *= 4. * number_of_time_samples / float(self.interferometers.duration) self.weights[ifo.name + '_linear'].append(linear_weights) if pyfftw is not None: pyfftw.forget_wisdom() @@ -885,6 +907,7 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): """ for ifo in self.interferometers: self.weights[ifo.name + '_linear'] = [] + xp = aac.array_namespace(self.interferometers.frequency_array) Tbs = linear_matrix['durations_s_linear'][()] / self.roq_scale_factor start_end_frequency_bins = linear_matrix['start_end_frequency_bins_linear'][()] basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1) @@ -892,29 +915,39 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): # prepare time-shifted data, which is multiplied by basis tc_shifted_data = dict() for ifo in self.interferometers: - over_whitened_frequency_data = np.zeros(int(fhigh_basis * ifo.duration) + 1, dtype=complex) - over_whitened_frequency_data[np.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask]] = \ + over_whitened_frequency_data = xp.zeros(int(fhigh_basis * ifo.duration) + 1, dtype=complex) + 
over_whitened_frequency_data = xpx.at( + over_whitened_frequency_data, xp.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask] + ).set( ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] - over_whitened_time_data = np.fft.irfft(over_whitened_frequency_data) - tc_shifted_data[ifo.name] = np.zeros((basis_dimension, len(self.weights['time_samples'])), dtype=complex) + ) + over_whitened_time_data = xp.fft.irfft(over_whitened_frequency_data) + tc_shifted_data[ifo.name] = xp.zeros((basis_dimension, len(self.weights['time_samples'])), dtype=complex) start_idx_of_band = 0 for b, Tb in enumerate(Tbs): start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b] - fs = np.arange(start_frequency_bin, end_frequency_bin + 1) / Tb - Db = np.fft.rfft( + fs = xp.arange(start_frequency_bin, end_frequency_bin + 1) / Tb + Db = xp.fft.rfft( over_whitened_time_data[-int(2. * fhigh_basis * Tb):] )[start_frequency_bin:end_frequency_bin + 1] start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1 - tc_shifted_data[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. / Tb * Db[:, None] * np.exp( - 2. * np.pi * 1j * fs[:, None] * (self.weights['time_samples'][None, :] - ifo.duration + Tb)) + this_data = xp.zeros(len(self.weights['time_samples']), dtype=complex) + sl = slice(start_idx_of_band, start_idx_of_next_band) + this_data = ( + 4. / Tb * Db[:, None] * xp.exp( + 2. 
* np.pi * 1j * fs[:, None] + * (xp.asarray(self.weights['time_samples'][None, :]) - ifo.duration + Tb) + ) + ) + tc_shifted_data[ifo.name] = xpx.at(tc_shifted_data[ifo.name], sl).set(this_data) + start_idx_of_band = start_idx_of_next_band # compute inner products for basis_idx in basis_idxs: logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.") - linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis'][()] + linear_matrix_single = xp.asarray(linear_matrix['basis_linear'][str(basis_idx)]['basis'][()]) for ifo in self.interferometers: - self.weights[ifo.name + '_linear'].append( - np.dot(np.conj(linear_matrix_single), tc_shifted_data[ifo.name]).T) + self.weights[ifo.name + '_linear'].append((linear_matrix_single.conj() @ tc_shifted_data[ifo.name]).T) def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idxs): """ @@ -936,14 +969,15 @@ def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idx """ for ifo in self.interferometers: self.weights[ifo.name + '_quadratic'] = [] + xp = aac.array_namespace(self.interferometers.frequency_array) for basis_idx in basis_idxs: logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.") - quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real + quadratic_matrix_single = xp.asarray(quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real) for ifo in self.interferometers: + inv_psd = xp.asarray(1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]]) self.weights[ifo.name + '_quadratic'].append( - 4. / ifo.strain_data.duration * np.dot( - quadratic_matrix_single[:, roq_idxs[ifo.name]], - 1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]])) + 4. 
/ ifo.strain_data.duration * quadratic_matrix_single[:, roq_idxs[ifo.name]] @ inv_psd + ) del quadratic_matrix_single def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): @@ -960,6 +994,7 @@ def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): """ for ifo in self.interferometers: self.weights[ifo.name + '_quadratic'] = [] + xp = aac.array_namespace(self.interferometers.frequency_array) Tbs = quadratic_matrix['durations_s_quadratic'][()] / self.roq_scale_factor start_end_frequency_bins = quadratic_matrix['start_end_frequency_bins_quadratic'][()] basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1) @@ -967,27 +1002,31 @@ def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): # prepare coefficients multiplied by basis multibanded_inverse_psd = dict() for ifo in self.interferometers: - inverse_psd_frequency = np.zeros(int(fhigh_basis * ifo.duration) + 1) - inverse_psd_frequency[np.arange(len(ifo.power_spectral_density_array))[ifo.frequency_mask]] = \ - 1. / ifo.power_spectral_density_array[ifo.frequency_mask] - inverse_psd_time = np.fft.irfft(inverse_psd_frequency) - multibanded_inverse_psd[ifo.name] = np.zeros(basis_dimension) + inverse_psd_frequency = xp.zeros(int(fhigh_basis * ifo.duration) + 1) + sl = np.arange(len(ifo.power_spectral_density_array))[ifo.frequency_mask] + inverse_psd_frequency = xpx.at(inverse_psd_frequency, sl).set( + 1. / xp.asarray(ifo.power_spectral_density_array[ifo.frequency_mask]) + ) + inverse_psd_time = xp.fft.irfft(inverse_psd_frequency) + multibanded_inverse_psd[ifo.name] = xp.zeros(basis_dimension) start_idx_of_band = 0 for b, Tb in enumerate(Tbs): start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b] number_of_samples_half = int(fhigh_basis * Tb) start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1 - multibanded_inverse_psd[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. 
/ Tb * np.fft.rfft( - np.append(inverse_psd_time[:number_of_samples_half], inverse_psd_time[-number_of_samples_half:]) + sl = slice(start_idx_of_band, start_idx_of_next_band) + this_data = 4. / Tb * xp.fft.rfft( + xp.concat([inverse_psd_time[:number_of_samples_half], inverse_psd_time[-number_of_samples_half:]]) )[start_frequency_bin:end_frequency_bin + 1].real + multibanded_inverse_psd[ifo.name] = xpx.at(multibanded_inverse_psd[ifo.name], sl).set(this_data) start_idx_of_band = start_idx_of_next_band # compute inner products for basis_idx in basis_idxs: logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.") - quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real + quadratic_matrix_single = xp.asarray(quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real) for ifo in self.interferometers: self.weights[ifo.name + '_quadratic'].append( - np.dot(quadratic_matrix_single, multibanded_inverse_psd[ifo.name])) + quadratic_matrix_single @ multibanded_inverse_psd[ifo.name]) def save_weights(self, filename, format='hdf5'): """ @@ -1203,8 +1242,8 @@ def generate_time_sample_from_marginalized_likelihood(self, signal_polarizations times = self._times if self.jitter_time: times = times + parameters["time_jitter"] - time_prior_array = self.priors['geocent_time'].prob(times) - time_post = np.exp(time_log_like - max(time_log_like)) * time_prior_array + time_prior_array = np.asarray(self.priors['geocent_time'].prob(times)) + time_post = np.exp(np.asarray(time_log_like - max(time_log_like))) * time_prior_array time_post /= np.sum(time_post) return random.rng.choice(times, p=time_post) diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index e262eaaf3..9127edeb2 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -1,12 +1,14 @@ import os import copy +import array_api_extra as xpx import numpy as np from scipy.integrate import cumulative_trapezoid, trapezoid, quad from scipy.interpolate import 
InterpolatedUnivariateSpline from scipy.special import hyp2f1 from scipy.stats import norm +from ..compat.utils import xp_wrap from ..core.prior import ( PriorDict, Uniform, Prior, DeltaFunction, Gaussian, Interped, Constraint, conditional_prior_factory, PowerLaw, ConditionalLogUniform, @@ -430,23 +432,24 @@ def __init__(self, minimum, maximum, name='mass_ratio', latex_label='$q$', def _integral(q): return -5. * q**(-1. / 5.) * hyp2f1(-2. / 5., -1. / 5., 4. / 5., -q) - def cdf(self, val): + def cdf(self, val, *, xp=np): return (self._integral(val) - self._integral(self.minimum)) / self.norm - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): if self.equal_mass: - val = 2 * np.minimum(val, 1 - val) + val = 2 * xp.minimum(val, 1 - val) return self.icdf(val) - def prob(self, val): + def prob(self, val, *, xp=np): in_prior = (val >= self.minimum) & (val <= self.maximum) with np.errstate(invalid="ignore"): prob = (1. + val)**(2. / 5.) / (val**(6. / 5.)) / self.norm * in_prior return prob - def ln_prob(self, val): + def ln_prob(self, val, *, xp=np): with np.errstate(divide="ignore"): - return np.log(self.prob(val)) + return xp.log(self.prob(val, xp=xp)) class AlignedSpin(Interped): @@ -511,7 +514,7 @@ def integrand(aa, chi): after performing the integral over spin orientation using a delta function identity. 
""" - return a_prior.prob(aa) * z_prior.prob(chi / aa) / aa + return a_prior.prob(aa, xp=None) * z_prior.prob(chi / aa, xp=None) / aa self.num_interp = 10_000 if num_interp is None else num_interp xx = np.linspace(chi_min, chi_max, self.num_interp) @@ -600,21 +603,26 @@ def __init__(self, minimum, maximum, name, latex_label=None, unit=None, boundary self.__class__.__name__ = "ConditionalChiInPlane" self.__class__.__qualname__ = "ConditionalChiInPlane" - def prob(self, val, **required_variables): - self.update_conditions(**required_variables) + @xp_wrap + def prob(self, val, *, xp=np, **required_variables): + parameters = self.condition_func(self.reference_params.copy(), **required_variables) chi_aligned = abs(required_variables[self._required_variables[0]]) + minimum = parameters.get("minimum", self.minimum) + maximum = parameters.get("maximum", self.maximum) return ( - (val >= self.minimum) * (val <= self.maximum) + (val >= minimum) * (val <= maximum) * val / (chi_aligned ** 2 + val ** 2) - / np.log(self._reference_maximum / chi_aligned) + / xp.log(self._reference_maximum / chi_aligned) ) - def ln_prob(self, val, **required_variables): + @xp_wrap + def ln_prob(self, val, *, xp=np, **required_variables): with np.errstate(divide="ignore"): - return np.log(self.prob(val, **required_variables)) + return xp.log(self.prob(val, **required_variables)) - def cdf(self, val, **required_variables): + @xp_wrap + def cdf(self, val, *, xp=np, **required_variables): r""" .. 
math:: \text{CDF}(\chi_\per) = N ln(1 + (\chi_\perp / \chi) ** 2) @@ -634,14 +642,15 @@ def cdf(self, val, **required_variables): """ self.update_conditions(**required_variables) chi_aligned = abs(required_variables[self._required_variables[0]]) - return np.maximum(np.minimum( + return xp.clip( (val >= self.minimum) * (val <= self.maximum) - * np.log(1 + (val / chi_aligned) ** 2) - / 2 / np.log(self._reference_maximum / chi_aligned) - , 1 - ), 0) + * xp.log(1 + (val / chi_aligned) ** 2) + / 2 / xp.log(self._reference_maximum / chi_aligned), + 0, + 1 + ) - def rescale(self, val, **required_variables): + def rescale(self, val, *, xp=np, **required_variables): r""" .. math:: \text{PPF}(\chi_\perp) = ((a_\max / \chi) ** (2x) - 1) ** 0.5 * \chi @@ -664,9 +673,9 @@ def rescale(self, val, **required_variables): def _condition_function(self, reference_params, **kwargs): with np.errstate(invalid="ignore"): - maximum = np.sqrt( + maximum = ( self._reference_maximum ** 2 - kwargs[self._required_variables[0]] ** 2 - ) + )**0.5 return dict(minimum=0, maximum=maximum) def __repr__(self): @@ -690,13 +699,13 @@ def __init__(self, minimum=-np.inf, maximum=np.inf): super().__init__(minimum=minimum, maximum=maximum, name=None, latex_label=None, unit=None) - def prob(self, val): + def prob(self, val, *, xp=np): """ Returns the result of the equation of state check in the conversion function. 
""" return val - def ln_prob(self, val): + def ln_prob(self, val, *, xp=np): if val: result = 0.0 @@ -1516,7 +1525,8 @@ def _check_imports(): raise ImportError("Must have healpy installed on this machine to use HealPixMapPrior") return healpy - def _rescale(self, samp, **kwargs): + @xp_wrap + def _rescale(self, samp, *, xp=np, **kwargs): """ Overwrites the _rescale method of BaseJoint Prior to rescale a single value from the unitcube onto two values (ra, dec) or 3 (ra, dec, dist) if distance is included @@ -1539,17 +1549,19 @@ def _rescale(self, samp, **kwargs): else: samp = samp[:, 0] pix_rescale = self.inverse_cdf(samp) - sample = np.empty((len(pix_rescale), 2)) - dist_samples = np.empty((len(pix_rescale))) + sample = xp.empty((len(pix_rescale), 2)) + dist_samples = xp.empty((len(pix_rescale))) for i, val in enumerate(pix_rescale): theta, ra = self.hp.pix2ang(self.nside, int(round(val))) dec = 0.5 * np.pi - theta - sample[i, :] = self.draw_from_pixel(ra, dec, int(round(val))) + sample = xpx.at(sample, i).set(xp.asarray(self.draw_from_pixel(ra, dec, int(round(val))))) if self.distance: self.update_distance(int(round(val))) - dist_samples[i] = self.distance_icdf(dist_samp[i]) + dist_samples = xpx.at(dist_samples, i).set( + xp.asarray(self.distance_icdf(dist_samp[i])) + ) if self.distance: - sample = np.vstack([sample[:, 0], sample[:, 1], dist_samples]) + sample = xp.vstack([sample[:, 0], sample[:, 1], dist_samples]) return sample.reshape((-1, self.num_vars)) def update_distance(self, pix_idx): @@ -1595,7 +1607,7 @@ def _check_norm(array): norm = np.finfo(array.dtype).eps return array / norm - def _sample(self, size, **kwargs): + def _sample(self, size, *, xp=np, **kwargs): """ Overwrites the _sample method of BaseJoint Prior. Picks a pixel value according to their probabilities, then uniformly samples ra, and decs that are contained in chosen pixel. 
If the PriorDist includes distance it then @@ -1626,7 +1638,7 @@ def _sample(self, size, **kwargs): sample[samp, :] = [ra_dec[0], ra_dec[1], dist] else: sample[samp, :] = self.draw_from_pixel(ra, dec, sample_pix[samp]) - return sample.reshape((-1, self.num_vars)) + return xp.asarray(sample.reshape((-1, self.num_vars))) def draw_distance(self, pix): """ @@ -1705,7 +1717,8 @@ def check_in_pixel(self, ra, dec, pix): pixel = self.hp.ang2pix(self.nside, theta, phi) return pix == pixel - def _ln_prob(self, samp, lnprob, outbounds): + @xp_wrap + def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): """ Overwrites the _lnprob method of BaseJoint Prior @@ -1731,11 +1744,13 @@ def _ln_prob(self, samp, lnprob, outbounds): phi, dec = samp[0] theta = 0.5 * np.pi - dec pixel = self.hp.ang2pix(self.nside, theta, phi) - lnprob[i] = np.log(self.prob[pixel] / self.pixel_area) + xpx.at(lnprob, i).set(xp.log(xp.asarray(self.prob[pixel] / self.pixel_area))) if self.distance: self.update_distance(pixel) - lnprob[i] += np.log(self.distance_pdf(dist) * dist ** 2) - lnprob[outbounds] = -np.inf + lnprob = xpx.at(lnprob, i).set( + lnprob[i] + xp.log(xp.asarray(self.distance_pdf(dist) * dist ** 2)) + ) + lnprob = xp.where(xp.asarray(outbounds), -np.inf, lnprob) return lnprob def __eq__(self, other): diff --git a/bilby/gw/sampler/proposal.py b/bilby/gw/sampler/proposal.py index 79e1ec92c..2ac84687e 100644 --- a/bilby/gw/sampler/proposal.py +++ b/bilby/gw/sampler/proposal.py @@ -13,7 +13,7 @@ class SkyLocationWanderJump(JumpProposal): def __call__(self, sample, **kwargs): temperature = 1 / kwargs.get('inverse_temperature', 1.0) - sigma = np.sqrt(temperature) / 2 / np.pi + sigma = temperature**0.5 / 2 / np.pi sample['ra'] += random.gauss(0, sigma) sample['dec'] += random.gauss(0, sigma) return super(SkyLocationWanderJump, self).__call__(sample) diff --git a/bilby/gw/source.py b/bilby/gw/source.py index a18a617ce..248738be8 100644 --- a/bilby/gw/source.py +++ b/bilby/gw/source.py @@ -1,5 +1,6 
@@ import numpy as np +from ..compat.utils import array_module from ..core import utils from ..core.utils import logger from .conversion import bilby_to_lalsimulation_spins @@ -1188,20 +1189,22 @@ def sinegaussian(frequency_array, hrss, Q, frequency, **kwargs): dict: Dictionary containing the plus and cross components of the strain. """ - tau = Q / (np.sqrt(2.0) * np.pi * frequency) - temp = Q / (4.0 * np.sqrt(np.pi) * frequency) + xp = array_module(frequency_array) + tau = Q / (2.0**0.5 * np.pi * frequency) + temp = Q / (4.0 * np.pi**0.5 * frequency) fm = frequency_array - frequency fp = frequency_array + frequency - h_plus = ((hrss / np.sqrt(temp * (1 + np.exp(-Q**2)))) * - ((np.sqrt(np.pi) * tau) / 2.0) * - (np.exp(-fm**2 * np.pi**2 * tau**2) + - np.exp(-fp**2 * np.pi**2 * tau**2))) + negative_term = xp.exp(-fm**2 * np.pi**2 * tau**2) + positive_term = xp.exp(-fp**2 * np.pi**2 * tau**2) - h_cross = (-1j * (hrss / np.sqrt(temp * (1 - np.exp(-Q**2)))) * - ((np.sqrt(np.pi) * tau) / 2.0) * - (np.exp(-fm**2 * np.pi**2 * tau**2) - - np.exp(-fp**2 * np.pi**2 * tau**2))) + h_plus = hrss * np.pi**0.5 * tau / 2 * ( + negative_term + positive_term + ) / (temp * (1 + xp.exp(-Q**2)))**0.5 + + h_cross = -1j * hrss * np.pi**0.5 * tau / 2 * ( + negative_term - positive_term + ) / (temp * (1 - xp.exp(-Q**2)))**0.5 return {'plus': h_plus, 'cross': h_cross} @@ -1284,12 +1287,13 @@ def supernova_pca_model( dict: The plus and cross polarizations of the signal """ + xp = array_module(frequency_array) principal_components = kwargs["realPCs"] + 1j * kwargs["imagPCs"] coefficients = [pc_coeff1, pc_coeff2, pc_coeff3, pc_coeff4, pc_coeff5] - strain = np.sum( - [coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)], + strain = xp.sum( + xp.asarray([coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)]), axis=0 ) diff --git a/bilby/gw/time.py b/bilby/gw/time.py new file mode 100644 index 000000000..3c115646b --- /dev/null +++ 
b/bilby/gw/time.py @@ -0,0 +1,211 @@ +import numpy as np +from plum import dispatch + +from ..compat.utils import array_module + + +__all__ = [ + "datetime", + "gps_time_to_utc", + "greenwich_mean_sidereal_time", + "greenwich_sidereal_time", + "n_leap_seconds", + "utc_to_julian_day", + "LEAP_SECONDS", +] + + +class datetime: + """ + A barebones datetime class for use in the GPS to GMST conversion. + """ + + def __init__( + self, + year: int = 0, + month: int = 0, + day: int = 0, + hour: int = 0, + minute: int = 0, + second: float = 0, + ): + self.year = year + self.month = month + self.day = day + self.hour = hour + self.minute = minute + self.second = second + + def __repr__(self): + return f"{self.year}-{self.month}-{self.day} {self.hour}:{self.minute}:{self.second}" + + def __add__(self, other): + """ + Add two datetimes together. + Note that this does not handle overflow and can lead to unphysical + values for the various attributes. + """ + return datetime( + self.year + other.year, + self.month + other.month, + self.day + other.day, + self.hour + other.hour, + self.minute + other.minute, + self.second + other.second, + ) + + @property + def julian_day(self): + return ( + 367 * self.year + - 7 * (self.year + (self.month + 9) // 12) // 4 + + 275 * self.month // 9 + + self.day + + self.second / SECONDS_PER_DAY + + JULIAN_GPS_EPOCH + ) + + +GPS_EPOCH = datetime(1980, 1, 6, 0, 0, 0) +JULIAN_GPS_EPOCH = 1721013.5 +EPOCH_J2000_0_JD = 2451545.0 +DAYS_PER_CENTURY = 36525.0 +SECONDS_PER_DAY = 86400.0 +LEAP_SECONDS = [ + 46828800, + 78364801, + 109900802, + 173059203, + 252028804, + 315187205, + 346723206, + 393984007, + 425520008, + 457056009, + 504489610, + 551750411, + 599184012, + 820108813, + 914803214, + 1025136015, + 1119744016, + 1167264017, +] + + +@dispatch +def gps_time_to_utc(gps_time): + """ + Convert GPS time to UTC. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + + Returns + ------- + datetime + UTC time. 
+ """ + return GPS_EPOCH + datetime(second=gps_time - n_leap_seconds(gps_time)) + + +@dispatch +def greenwich_mean_sidereal_time(gps_time): + """ + Calculate the Greenwich Mean Sidereal Time. + + This is a thin wrapper around :py:func:`greenwich_sidereal_time` with the + equation of the equinoxes set to zero. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + + Returns + ------- + float + Greenwich Mean Sidereal Time in radians. + """ + return greenwich_sidereal_time(gps_time, gps_time * 0) + + +@dispatch +def greenwich_sidereal_time(gps_time, equation_of_equinoxes): + """ + Calculate the Greenwich Sidereal Time. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + equation_of_equinoxes : float + Equation of the equinoxes in seconds. + + Returns + ------- + float + """ + julian_day = utc_to_julian_day(gps_time_to_utc(gps_time // 1)) + t_hi = (julian_day - EPOCH_J2000_0_JD) / DAYS_PER_CENTURY + t_lo = (gps_time % 1) / (DAYS_PER_CENTURY * SECONDS_PER_DAY) + + t = t_hi + t_lo + + sidereal_time = ( + equation_of_equinoxes + (-6.2e-6 * t + 0.093104) * t**2 + 67310.54841 + ) + sidereal_time += 8640184.812866 * t_lo + sidereal_time += 3155760000.0 * t_lo + sidereal_time += 8640184.812866 * t_hi + sidereal_time += 3155760000.0 * t_hi + + return sidereal_time * 2 * np.pi / SECONDS_PER_DAY + + +@dispatch +def n_leap_seconds(gps_time, leap_seconds): + """ + Calculate the number of leap seconds that have occurred up to a given GPS time. + + Parameters + ---------- + gps_time : float | np.ndarray | int + GPS time in seconds. + leap_seconds : array_like + GPS time of leap seconds. 
+ + Returns + ------- + float + Number of leap seconds + """ + xp = array_module(gps_time) + return xp.sum(gps_time > leap_seconds[:, None], axis=0).squeeze() + + +@dispatch +def n_leap_seconds(gps_time: np.ndarray | float | int): # noqa F811 + xp = array_module(gps_time) + return n_leap_seconds(gps_time, xp.asarray(LEAP_SECONDS)) + + +@dispatch +def utc_to_julian_day(utc_time): + """ + Convert UTC time to Julian day. + + Parameters + ---------- + utc_time : datetime + UTC time. + + Returns + ------- + float + Julian day. + + """ + return utc_time.julian_day diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index f1f4c0291..bbec47fa3 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -5,18 +5,18 @@ import numpy as np from scipy.interpolate import interp1d from scipy.special import i0e -from bilby_cython.geometry import ( - zenith_azimuth_to_theta_phi as _zenith_azimuth_to_theta_phi, -) -from bilby_cython.time import greenwich_mean_sidereal_time +from .geometry import zenith_azimuth_to_theta_phi +from .time import greenwich_mean_sidereal_time +from ..compat.utils import array_module, xp_wrap from ..core.utils import (logger, run_commandline, check_directory_exists_and_if_not_mkdir, SamplesSummary, theta_phi_to_ra_dec) from ..core.utils.constants import solar_mass -def asd_from_freq_series(freq_data, df): +@xp_wrap +def asd_from_freq_series(freq_data, df, *, xp=None): """ Calculate the ASD from the frequency domain output of gaussian_noise() @@ -32,10 +32,11 @@ def asd_from_freq_series(freq_data, df): array_like: array of real-valued normalized frequency domain ASD data """ - return np.absolute(freq_data) * 2 * df**0.5 + return xp.abs(freq_data) * 2 * df**0.5 -def psd_from_freq_series(freq_data, df): +@xp_wrap +def psd_from_freq_series(freq_data, df, *, xp=None): """ Calculate the PSD from the frequency domain output of gaussian_noise() Calls asd_from_freq_series() and squares the output @@ -52,7 +53,7 @@ def psd_from_freq_series(freq_data, df): array_like: 
Real-valued normalized frequency domain PSD data """ - return np.power(asd_from_freq_series(freq_data, df), 2) + return asd_from_freq_series(freq_data, df, xp=xp) ** 2 def get_vertex_position_geocentric(latitude, longitude, elevation): @@ -76,14 +77,15 @@ def get_vertex_position_geocentric(latitude, longitude, elevation): array_like: A 3D representation of the geocentric vertex position """ + xp = array_module(latitude) semi_major_axis = 6378137 # for ellipsoid model of Earth, in m semi_minor_axis = 6356752.314 # in m - radius = semi_major_axis**2 * (semi_major_axis**2 * np.cos(latitude)**2 + - semi_minor_axis**2 * np.sin(latitude)**2)**(-0.5) - x_comp = (radius + elevation) * np.cos(latitude) * np.cos(longitude) - y_comp = (radius + elevation) * np.cos(latitude) * np.sin(longitude) - z_comp = ((semi_minor_axis / semi_major_axis)**2 * radius + elevation) * np.sin(latitude) - return np.array([x_comp, y_comp, z_comp]) + radius = semi_major_axis**2 * (semi_major_axis**2 * xp.cos(latitude)**2 + + semi_minor_axis**2 * xp.sin(latitude)**2)**(-0.5) + x_comp = (radius + elevation) * xp.cos(latitude) * xp.cos(longitude) + y_comp = (radius + elevation) * xp.cos(latitude) * xp.sin(longitude) + z_comp = ((semi_minor_axis / semi_major_axis)**2 * radius + elevation) * xp.sin(latitude) + return xp.asarray([x_comp, y_comp, z_comp]) def inner_product(aa, bb, frequency, PSD): @@ -106,11 +108,11 @@ def inner_product(aa, bb, frequency, PSD): psd_interp = PSD.power_spectral_density_interpolated(frequency) # calculate the inner product - integrand = np.conj(aa) * bb / psd_interp + integrand = (aa.conj() * bb / psd_interp).real df = frequency[1] - frequency[0] - integral = np.sum(integrand) * df - return 4. * np.real(integral) + integral = integrand.sum() * df + return 4. 
* integral def noise_weighted_inner_product(aa, bb, power_spectral_density, duration): @@ -132,9 +134,8 @@ def noise_weighted_inner_product(aa, bb, power_spectral_density, duration): ======= Noise-weighted inner product. """ - - integrand = np.conj(aa) * bb / power_spectral_density - return 4 / duration * np.sum(integrand) + integrand = aa.conj() * bb / power_spectral_density + return 4 / duration * integrand.sum() def matched_filter_snr(signal, frequency_domain_strain, power_spectral_density, duration): @@ -222,34 +223,12 @@ def overlap(signal_a, signal_b, power_spectral_density=None, delta_frequency=Non """ low_index = int(lower_cut_off / delta_frequency) up_index = int(upper_cut_off / delta_frequency) - integrand = np.conj(signal_a) * signal_b + integrand = signal_a.conj() * signal_b integrand = integrand[low_index:up_index] / power_spectral_density[low_index:up_index] integral = (4 * delta_frequency * integrand) / norm_a / norm_b return sum(integral).real -def zenith_azimuth_to_theta_phi(zenith, azimuth, ifos): - """ - Convert from the 'detector frame' to the Earth frame. - - Parameters - ========== - kappa: float - The zenith angle in the detector frame - eta: float - The azimuthal angle in the detector frame - ifos: list - List of Interferometer objects defining the detector frame - - Returns - ======= - theta, phi: float - The zenith and azimuthal angles in the earth frame. - """ - delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex - return _zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) - - def zenith_azimuth_to_ra_dec(zenith, azimuth, geocent_time, ifos): """ Convert from the 'detector frame' to the Earth frame. @@ -945,7 +924,8 @@ def lalsim_SimNeutronStarLoveNumberK2(mass_in_SI, fam): return SimNeutronStarLoveNumberK2(mass_in_SI, fam) -def spline_angle_xform(delta_psi): +@xp_wrap +def spline_angle_xform(delta_psi, *, xp=None): """ Returns the angle in degrees corresponding to the spline calibration parameters delta_psi. 
@@ -962,7 +942,7 @@ def spline_angle_xform(delta_psi): """ rotation = (2.0 + 1.0j * delta_psi) / (2.0 - 1.0j * delta_psi) - return 180.0 / np.pi * np.arctan2(np.imag(rotation), np.real(rotation)) + return 180.0 / np.pi * xp.arctan2(xp.imag(rotation), xp.real(rotation)) def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label=None, xform=None): @@ -1023,7 +1003,8 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label= plt.xlim(freq_points.min() - .5, freq_points.max() + 50) -def ln_i0(value): +@xp_wrap +def ln_i0(value, *, xp=None): """ A numerically stable method to evaluate ln(I_0) a modified Bessel function of order 0 used in the phase-marginalized likelihood. @@ -1038,7 +1019,7 @@ def ln_i0(value): array-like: The natural logarithm of the bessel function """ - return np.log(i0e(value)) + np.abs(value) + return xp.log(i0e(value)) + xp.abs(value) def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1): @@ -1067,10 +1048,10 @@ def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1): import lalsimulation return safety * lalsimulation.SimInspiralTaylorF2ReducedSpinChirpTime( - frequency, - mass_1 * solar_mass, - mass_2 * solar_mass, - chi, + float(frequency), + float(mass_1 * solar_mass), + float(mass_2 * solar_mass), + float(chi), -1 ) diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py index e0f8b4e23..4043caa0d 100644 --- a/bilby/gw/waveform_generator.py +++ b/bilby/gw/waveform_generator.py @@ -1,3 +1,4 @@ +import array_api_compat as aac import numpy as np from ..core import utils @@ -24,7 +25,8 @@ class WaveformGenerator(object): def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None, time_domain_source_model=None, parameters=None, parameter_conversion=None, - waveform_arguments=None): + waveform_arguments=None, use_cache=True, + ): """ The base waveform generator class. 
@@ -57,6 +59,10 @@ def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequen Note: the arguments of frequency_domain_source_model (except the first, which is the frequencies at which to compute the strain) will be added to the WaveformGenerator object and initialised to `None`. + use_cache: bool + Whether to attempt caching the waveform between subsequent calls. + This is :code:`True` by default but must be disabled for JIT compilation + with :code:`JAX`. """ self._times_and_frequencies = CoupledTimeAndFrequencySeries(duration=duration, @@ -73,9 +79,13 @@ def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequen self.waveform_arguments = waveform_arguments else: self.waveform_arguments = dict() - if isinstance(parameters, dict): - self.parameters = parameters + if parameters is not None: + logger.warning( + "Setting initial parameters via the 'parameters' argument is " + "deprecated and will be removed in a future release." + ) self._cache = dict(parameters=None, waveform=None, model=None) + self.use_cache = use_cache logger.info(f"Waveform generator instantiated: {self}") def __repr__(self): @@ -102,15 +112,13 @@ def __repr__(self): def frequency_domain_strain(self, parameters=None): """ Wrapper to source_model. - Converts self.parameters with self.parameter_conversion before handing it off to the source model. + Converts parameters with self.parameter_conversion before handing it off to the source model. Automatically refers to the time_domain_source model via NFFT if no frequency_domain_source_model is given. Parameters ========== parameters: dict, optional - Parameters to evaluate the waveform for, this overwrites - `self.parameters`. - If not provided will fall back to `self.parameters`. + If not provided will use the last parameters used. Returns ======= @@ -131,16 +139,14 @@ def frequency_domain_strain(self, parameters=None): def time_domain_strain(self, parameters=None): """ Wrapper to source_model. 
- Converts self.parameters with self.parameter_conversion before handing it off to the source model. + Converts parameters with self.parameter_conversion before handing it off to the source model. Automatically refers to the frequency_domain_source model via INFFT if no frequency_domain_source_model is given. Parameters ========== parameters: dict, optional - Parameters to evaluate the waveform for, this overwrites - `self.parameters`. - If not provided will fall back to `self.parameters`. + If not provided will use the last parameters used. Returns ======= @@ -161,9 +167,15 @@ def time_domain_strain(self, parameters=None): def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model, transformed_model_data_points, parameters): if parameters is None: - parameters = self.parameters - if parameters == self._cache['parameters'] and self._cache['model'] == model and \ - self._cache['transformed_model'] == transformed_model: + parameters = self._cache.get('parameters', None) + if parameters is None: + raise ValueError("No parameters given to generate waveform.") + if ( + self.use_cache + and parameters == self._cache.get('parameters', None) + and self._cache['model'] == model + and self._cache['transformed_model'] == transformed_model + ): return self._cache['waveform'] else: self._cache['parameters'] = parameters.copy() @@ -190,7 +202,7 @@ def _strain_from_transformed_model( transformed_model_data_points, transformed_model, parameters ) - if isinstance(transformed_model_strain, np.ndarray): + if aac.is_array_api_obj(transformed_model_strain): return transformation_function(transformed_model_strain, self.sampling_frequency) model_strain = dict() @@ -507,7 +519,9 @@ def frequency_domain_strain(self, parameters): from lalsimulation.gwsignal import GenerateFDWaveform if parameters is None: - parameters = self.parameters + parameters = self._cache.get("parameters", None) + if parameters is None: + raise ValueError("No parameters given to 
generate waveform.") hpc = _try_waveform_call( GenerateFDWaveform, @@ -541,7 +555,9 @@ def time_domain_strain(self, parameters): from lalsimulation.gwsignal import GenerateTDWaveform if parameters is None: - parameters = self.parameters + parameters = self._cache.get("parameters", None) + if parameters is None: + raise ValueError("No parameters given to generate waveform.") hpc = _try_waveform_call( GenerateTDWaveform, diff --git a/bilby/hyper/likelihood.py b/bilby/hyper/likelihood.py index 4b691845e..629324a09 100644 --- a/bilby/hyper/likelihood.py +++ b/bilby/hyper/likelihood.py @@ -3,6 +3,7 @@ import numpy as np +from ..compat.utils import array_module from ..core.likelihood import Likelihood, _fallback_to_parameters from .model import Model from ..core.prior import PriorDict @@ -29,11 +30,13 @@ class HyperparameterLikelihood(Likelihood): the sampling prior and the hyperparameterised model. max_samples: int, optional Maximum number of samples to use from each set. + xp: module + The array backend to use for the data. 
""" def __init__(self, posteriors, hyper_prior, sampling_prior=None, - log_evidences=None, max_samples=1e100): + log_evidences=None, max_samples=1e100, xp=np): if not isinstance(hyper_prior, Model): hyper_prior = Model([hyper_prior]) if sampling_prior is None: @@ -53,7 +56,7 @@ def __init__(self, posteriors, hyper_prior, sampling_prior=None, self.max_samples = max_samples super(HyperparameterLikelihood, self).__init__() - self.data = self.resample_posteriors() + self.data = self.resample_posteriors(xp=xp) self.n_posteriors = len(self.posteriors) self.samples_per_posterior = self.max_samples self.samples_factor =\ @@ -61,10 +64,11 @@ def __init__(self, posteriors, hyper_prior, sampling_prior=None, def log_likelihood_ratio(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) - log_l = np.sum(np.log(np.sum(self.hyper_prior.prob(self.data, **parameters) / - self.data['prior'], axis=-1))) + probs = self.hyper_prior.prob(self.data, **parameters) + xp = array_module(probs) + log_l = xp.sum(xp.log(xp.sum(probs / self.data['prior'], axis=-1))) log_l += self.samples_factor - return np.nan_to_num(log_l) + return xp.nan_to_num(log_l) def noise_log_likelihood(self): return self.evidence_factor @@ -72,7 +76,7 @@ def noise_log_likelihood(self): def log_likelihood(self, parameters=None): return self.noise_log_likelihood() + self.log_likelihood_ratio(parameters=parameters) - def resample_posteriors(self, max_samples=None): + def resample_posteriors(self, max_samples=None, xp=np): """ Convert list of pandas DataFrame object to dict of arrays. 
@@ -107,5 +111,5 @@ def resample_posteriors(self, max_samples=None): for key in data: data[key].append(temp[key]) for key in data: - data[key] = np.array(data[key]) + data[key] = xp.asarray(data[key]) return data diff --git a/docs/array_api.rst b/docs/array_api.rst new file mode 100644 index 000000000..8ce1cc043 --- /dev/null +++ b/docs/array_api.rst @@ -0,0 +1,550 @@ +===================== +Array API Support +===================== + +Bilby now supports the Python `Array API Standard <https://data-apis.org/array-api/latest/>`_, +enabling the use of different array backends (NumPy, JAX, CuPy, etc.) for improved performance +and hardware acceleration. This page describes how to use this functionality and how it works internally. + +For Users and Downstream Developers +==================================== + +Overview +-------- + +The Array API support allows you to use different array libraries with Bilby seamlessly. +This can significantly improve performance, especially when using hardware accelerators like GPUs +or when you need automatic differentiation capabilities. + +**Key principle**: In most cases, you don't need to explicitly specify which array backend to use. +Bilby automatically detects the array type you're working with and uses the appropriate backend. +Simply pass JAX arrays, CuPy arrays, or NumPy arrays to prior methods, and Bilby handles the rest. + +Supported Backends +------------------ + +Bilby is currently tested with the following array backends: + +- **NumPy** (default): Standard CPU-based computations +- **JAX**: GPU/TPU acceleration and automatic differentiation +- **PyTorch**: GPU acceleration and deep learning integration. + :code:`PyTorch` support is not complete, for example, functionality + requiring interpolation is not available. + +While :code:`Bilby` should be compatible with other Array API compliant libraries, +these are not currently tested or officially supported. +If you notice any issues when using other backends, +please report them on the `Bilby GitHub repository <https://github.com/bilby-dev/bilby>`_. 
+ +Using Different Array Backends +------------------------------- + +Basic Prior Usage (Automatic Detection) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The array backend is automatically detected from your input arrays. You typically don't need +to specify the ``xp`` parameter: + +.. code-block:: python + + import bilby + import jax.numpy as jnp + import numpy as np + + prior = bilby.core.prior.Uniform(minimum=0, maximum=10) + + # Using JAX - backend automatically detected + val_jax = jnp.array([0.5, 1.5, 2.5]) + prob_jax = prior.prob(val_jax) # Returns JAX array + + # Using NumPy - backend automatically detected + val_np = np.array([0.5, 1.5, 2.5]) + prob_np = prior.prob(val_np) # Returns NumPy array + +Sampling with Array Backends (Explicit xp Required) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When sampling from priors, you **must** explicitly specify the array backend using the ``xp`` parameter, +as there's no input array to infer the backend from: + +.. code-block:: python + + import bilby + import jax.numpy as jnp + + prior = bilby.core.prior.Uniform(minimum=0, maximum=10) + samples = prior.sample(size=1000, xp=jnp) # Returns JAX array + + # Or with NumPy (default) + samples_np = prior.sample(size=1000) # Or explicitly: xp=np + +.. note:: + + Currently, prior sampling is done by first generating uniform samples in [0, 1] + using :code:`NumPy`, then converting to the desired backend. + In future releases, this may be altered to generate samples directly in the specified backend. + +Prior Dictionaries +~~~~~~~~~~~~~~~~~~ + +Prior dictionaries work the same way - automatic detection for most methods, explicit ``xp`` for sampling: + +.. 
code-block:: python + + import bilby + import jax.numpy as jnp + + priors = bilby.core.prior.PriorDict({ + 'x': bilby.core.prior.Uniform(0, 100), + 'y': bilby.core.prior.Uniform(0, 1) + }) + + # Sampling requires explicit xp + samples = priors.sample(size=1000, xp=jnp) + + # Evaluation automatically detects backend from input + theta = {'x': jnp.asarray(50.0), 'y': jnp.asarray(0.5)} + prob = priors.prob(theta) # Automatically uses JAX + +Core Likelihoods and Sampling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Core :code:`Bilby` likelihoods are compatible with the Array API. +When using :code:`JAX` arrays, you can take advantage of :code:`JAX`'s JIT compilation and automatic differentiation. +For :code:`JAX`-compatible samplers (e.g., :code:`numpyro`), +you can pass any :code:`JAX`-compatible :code:`Bilby` likelihood directly. +For non-:code:`JAX` samplers, you should wrap your likelihood with the +:code:`bilby.compat.jax.JittedLikelihood` class to enable JIT compilation. + +.. code-block:: python + + import bilby + import jax.numpy as jnp + from bilby.compat.jax import JittedLikelihood + + class MyLikelihood(bilby.Likelihood): + def log_likelihood(self, parameters): + # model returns a JAX array if passed a dictionary of JAX arrays + return -0.5 * jnp.sum((self.data - model(parameters))**2) + + data = jnp.array([...]) # Your data as a JAX array + + priors = bilby.core.prior.PriorDict({ + 'param1': bilby.core.prior.Uniform(0, 10), + 'param2': bilby.core.prior.Uniform(-5, 5) + }) + + likelihood = MyLikelihood(data) + + # call the likelihood once in case any initial setup is needed + likelihood.log_likelihood(priors.sample()) + + # Wrap with JittedLikelihood for JAX + jitted_likelihood = JittedLikelihood(likelihood) + + # call the jitted likelihood once to trigger JIT compilation + # the JittedLikelihood automatically converts the parameters + # to JAX arrays + jitted_likelihood.log_likelihood(priors.sample()) + + # Use with a JAX-incompatible sampler + sampler = 
bilby.run_sampler(likelihood=jitted_likelihood, ...) + +Gravitational-Wave Likelihoods +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The :code:`Bilby` implementation of gravitational-wave likelihood is compatible with the Array API, +however this requires access to waveform models that support the provided array backend. +The desired array backend must be explicitly specified for the data, +using :code:`bilby.gw.detector.networks.InterferometerList.set_array_backend`. +Below is an example using the :code:`ripplegw` package for waveform generation. +Here, an injection is performed using the standard :code:`LALSimulation` waveform generator, +and the analysis is then performed using the JIT-compiled likelihood. + +.. code-block:: python + + import bilby + import jax.numpy as jnp + import ripplegw + + priors = bilby.gw.prior.BBHPriorDict() + priors["geocent_time"] = bilby.core.prior.Uniform(1126259462.4, 1126259462.6) + injection_parameters = priors.sample() + + # Create interferometers and inject signal using standard waveform generator + ifos = bilby.gw.detector.networks.InterferometerList(['H1', 'L1']) + ifos.set_strain_data_from_power_spectral_densities( + sampling_frequency=2048, + duration=4, + start_time=injection_parameters["geocent_time"] - 2 + ) + injection_wfg = bilby.gw.waveform_generator.WaveformGenerator( + duration=4, + sampling_frequency=2048, + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, + waveform_arguments={"approximant": "IMRPhenomXODE"} + ) + ifos.inject_signal(parameters=injection_parameters, waveform_generator=injection_wfg) + + # set the array backend after the injection + ifos.set_array_backend(jnp) + + ripple_wfg = bilby.gw.waveform_generator.WaveformGenerator( + duration=4, + sampling_frequency=2048, + frequency_domain_source_model=ripplegw.get_fd_waveform + ) + + # Create gravitational-wave likelihood + likelihood = bilby.gw.likelihood.GravitationalWaveTransient( + interferometers=ifos, + waveform_generator=ripple_wfg, + 
priors=priors, + phase_marginalization=True, + ) + # call the likelihood once to do some initial setup + # this is needed for the gravitational-wave transient likelihoods + likelihood.log_likelihood_ratio(priors.sample()) + + # Wrap with JittedLikelihood for JAX and JIT compile + jitted_likelihood = bilby.compat.jax.JittedLikelihood(likelihood) + jitted_likelihood.log_likelihood_ratio(priors.sample()) + +.. note:: + + All of the likelihood marginalizations implemented in :code:`Bilby` are compatible with the Array API. + However, there is currently a performance issue with the distance marginalized likelihood + using the :code:`JAX` backend. + +.. warning:: + + Some array backends (notably :code:`torch`) are more picky than others about data types. + For maximal consistency, try to consistently pass zero-dimensional arrays rather than :code:`Python` + scalars, e.g., :code:`torch.tensor(1.0)` instead of :code:`1.0`. + +Performance Considerations +-------------------------- + +**When to use JAX:** + +- GPU/TPU acceleration is available +- You need automatic differentiation +- Working with large datasets or many parameters +- Repeated evaluations benefit from JIT compilation + +**When to use NumPy:** + +- Simple CPU-based computations +- Small datasets +- Maximum compatibility +- Debugging (easier to inspect values) + +**Best Practices:** + +1. Let Bilby detect the array backend automatically - only specify ``xp`` when sampling +2. Use array backend consistently throughout your analysis +3. Avoid mixing array types in the same computation +4. For JAX, consider using ``jax.jit`` for repeated computations +5. Profile your code to ensure the chosen backend provides benefits +6. If you find :code:`xp_wrap` is a bottleneck in your code, you can explicitly pass + :code:`xp` to the function/method to skip the automatic backend detection step. + +Bilby and JIT compilation +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Currently, Bilby functions are not JIT-compiled by default. 
+Additionally, many Bilby types are not defined as :code:`JAX` :code:`PyTrees`, +and so cannot be passed as arguments to JIT-compiled functions. +We plan to support JIT-compilation for at least some Bilby types in future releases. + +Custom Priors with Array API +----------------------------- + +When creating custom priors, ensure they support the Array API: + +Example Implementation +~~~~~~~~~~~~~~~~~~~~~~ + +Always include the ``xp`` parameter with a default value: + +.. code-block:: python + + from bilby.core.prior import Prior + + class MyCustomPrior(Prior): + def __init__(self, parameter, **kwargs): + super().__init__(**kwargs) + self.parameter = parameter + + def rescale(self, val, *, xp=None): + """Rescale method with xp parameter.""" + return self.minimum + val * (self.maximum - self.minimum) * self.parameter + + def prob(self, val, *, xp=None): + """Probability method with xp parameter.""" + in_range = (val >= self.minimum) & (val <= self.maximum) + return in_range / (self.maximum - self.minimum) * self.parameter + +The ``xp`` parameter should: + +- Be a keyword-only argument (after ``*``) +- Have a default value (``None`` if method is decorated with ``@xp_wrap``, ``np`` otherwise) +- Be passed through to any array operations if used directly + +**Note**: Users of your custom prior won't need to pass ``xp`` explicitly for evaluation methods - +it will be automatically inferred from their input arrays. They only need to specify ``xp`` when sampling. + +Using the :code:`xp_wrap`` Decorator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For methods that perform array operations, use the ``@xp_wrap`` decorator: + +.. 
code-block:: python + + from bilby.core.prior import Prior + from bilby.compat.utils import xp_wrap + import numpy as np + + class MyCustomPrior(Prior): + @xp_wrap + def prob(self, val, *, xp=None): + """The decorator handles xp=None automatically.""" + return xp.exp(-val) / self.normalization * self.is_in_prior_range(val) + + @xp_wrap + def ln_prob(self, val, *, xp=None): + """Works with logarithmic operations.""" + return -val - xp.log(self.normalization) + xp.log(self.is_in_prior_range(val)) + +The ``@xp_wrap`` decorator: + +- Automatically provides the appropriate array module when ``xp=None`` +- Infers the array backend from input arrays when they are :code:`JAX`/:code:`CuPy`/:code:`PyTorch` arrays +- Falls back to NumPy when the input is a standard Python type or NumPy array +- Handles the conversion seamlessly so users don't need to specify ``xp`` + +Missing functionality +--------------------- + +The most significant missing functionality is the lack of a consistent random number generation +interface across different array backends. +Currently, all random calls use :code:`numpy.random` with the seed specified as described in :doc:`rng`. +This means that functionality like prior sampling and generating noise realizations in gravitational-wave +detectors will not be :code:`JIT`-compatible. + +For Bilby Developers +===================== + +Architecture Overview +--------------------- + +The Array API support in Bilby is built around several key components: + +1. **The xp parameter**: A keyword-only parameter added to prior methods +2. **The @xp_wrap decorator**: Handles array module selection and injection +3. 
**Compatibility utilities**: Helper functions for array module detection + +Core Changes to Prior Base Class +--------------------------------- + +The ``Prior`` base class in ``bilby/core/prior/base.py`` includes these key changes: + +Method Signature Pattern +~~~~~~~~~~~~~~~~~~~~~~~~ + +All array-processing methods in prior classes follow this pattern: + +**For methods with @xp_wrap decorator**: + +.. code-block:: python + + @xp_wrap + def prob(self, val, *, xp=None): + """Method that uses xp for array operations.""" + return xp.some_operation(val) * self.is_in_prior_range(val) + +**For methods without @xp_wrap (that use xp directly)**: + +.. code-block:: python + + def sample(self, size=None, *, xp=np): + """Method that uses xp but isn't wrapped.""" + return xp.asarray(random.rng.uniform(0, 1, size)) + +Key rules: + +- ``xp`` is always keyword-only (after ``*``) +- Methods with ``@xp_wrap`` use ``xp=None`` as default +- Methods without ``@xp_wrap`` that use ``xp`` use ``xp=np`` as default +- Methods that don't use ``xp`` have ``xp=None`` as default + +The :code:`@xp_wrap`` Decorator +------------------------------- + +Located in ``bilby/compat/utils.py``, this decorator: + +1. **Inspects input arguments** to determine the array module in use +2. **Provides the appropriate xp** when ``xp=None`` +3. **Maintains backward compatibility** with code that doesn't pass ``xp`` + +Example implementation pattern: + +.. code-block:: python + + from bilby.compat.utils import xp_wrap + + @xp_wrap + def my_function(val, *, xp=None): + # When called: + # - If xp=None, decorator infers from val + # - If xp is provided, uses that + # - Returns results in the same array type as input + return xp.exp(val) / xp.mean(val) + +Testing Array API Support +------------------------- + +Test Structure +~~~~~~~~~~~~~~ + +When appropriate, tests should verify functionality across different +backends using the ``array_backend`` marker: + +.. 
code-block:: python + + @pytest.mark.array_backend + @pytest.mark.usefixtures("xp_class") + class TestMyPrior: + def test_prob(self): + prior = MyPrior() + val = self.xp.asarray([0.5, 1.5, 2.5]) + # No need to pass xp - automatically detected + prob = prior.prob(val) + assert self.xp.all(prob >= 0) + assert aac.get_namespace(prob) == self.xp + + def test_sample(self): + prior = MyPrior() + # Sampling requires explicit xp + samples = prior.sample(size=100, xp=self.xp) + assert aac.get_namespace(samples) == self.xp + +The array_backend Marker +~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``@pytest.mark.array_backend`` marker is used to indicate that a test or test class should be run +with multiple array backends. When you run pytest with the ``--array-backend`` flag, only tests marked +with ``array_backend`` will be executed with that specific backend. + +Without the marker, tests run with the default NumPy backend only. With the marker: + +- Tests are parametrized to run with different backends +- The ``xp_class`` fixture is available, providing access to the array module via ``self.xp`` +- Tests verify that code works correctly regardless of the array backend + +Running Tests with Different Backends +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use the ``--array-backend`` flag to test with specific backends:: + + # Test with NumPy (default) + pytest test/core/prior/analytical_test.py + + # Test with JAX backend + pytest --array-backend jax test/core/prior/analytical_test.py + + # Test with CuPy backend + pytest --array-backend cupy test/core/prior/analytical_test.py + +Bilby automatically sets ``SCIPY_ARRAY_API=1`` on import, so you don't need to set this +environment variable manually. The ``--array-backend`` flag controls which backend the +``xp_class`` fixture provides to your tests. + +Migration Guide from Previous Versions +-------------------------------------- + +Key Differences +~~~~~~~~~~~~~~~ + +1. 
**Method signatures changed**: All prior methods now include ``xp`` parameter +2. **Decorator added**: Many methods now use ``@xp_wrap`` +3. **Default values differ**: Methods with ``@xp_wrap`` use ``xp=None``, others use ``xp=np`` +4. **Validation added**: Custom priors are checked for ``xp`` support + +Best Practices for Contributors +-------------------------------- + +When adding or modifying prior methods: + +1. **Always include xp parameter** in prob, ln_prob, rescale, cdf, sample methods +2. **Use @xp_wrap decorator** for methods doing array operations +3. **Set correct default**: ``xp=None`` with decorator, ``xp=np`` without (for methods that use xp directly) +4. **Pass xp through**: When calling other methods, pass ``xp=xp`` +5. **Test with multiple backends**: Use ``@pytest.mark.array_backend`` and test with ``--array-backend jax`` +6. **Document xp parameter**: Note it in docstrings, but emphasize it's usually auto-detected +7. **Use array module functions**: Use ``xp.function()`` not ``np.function()`` in wrapped methods + +Handling Array Updates with :code:`array_api_extra.at`` +------------------------------------------------------- + +One key difference between array backends is how they handle array updates. +NumPy allows in-place modification of array slices, +while JAX requires functional updates since arrays are immutable. +The ``array_api_extra.at`` function provides a unified interface for array updates across backends. + +Usage Examples +~~~~~~~~~~~~~~ + +**Conditional update**: + +.. code-block:: python + + @xp_wrap + def conditional_update(vals, *, xp=None): + """Update array elements where mask is True.""" + arr = vals**2 + mask = arr > 0.5 + # Instead of: arr[mask] = value + arr = xpx.at(arr)[mask].set(value) + return arr + +**Increment operation**: + +.. 
code-block:: python + + @xp_wrap + def increment_slice(arr, *, xp=None): + """Add values to a slice of an array.""" + # Instead of: arr[2:5] += values + arr = xpx.at(arr)[2:5].add(values) + return arr + +Available Operations +~~~~~~~~~~~~~~~~~~~~ + +The ``at`` function supports several operations: + +- ``set(values)``: Replace values at specified indices +- ``add(values)``: Add values to specified indices +- ``multiply(values)``: Multiply specified indices by values +- ``min(values)``: Take element-wise minimum +- ``max(values)``: Take element-wise maximum + +Important Notes +~~~~~~~~~~~~~~~ + +1. **Return value**: Always use the returned array. The operation may create a new array (JAX) or modify in-place (NumPy). + +2. **Import**: Import ``array_api_extra`` at the module level: + +.. code-block:: python + + import array_api_extra as xpx + +Further Resources +----------------- + +- `Array API Standard <https://data-apis.org/array-api/latest/>`_ +- `JAX Documentation <https://jax.readthedocs.io/>`_ +- `array-api-compat Package <https://data-apis.org/array-api-compat/>`_ +- `array-api-extra Package <https://data-apis.org/array-api-extra/>`_ diff --git a/docs/index.txt b/docs/index.txt index ff6e12c85..d8fabb550 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -16,6 +16,7 @@ Welcome to bilby's documentation! prior likelihood samplers + array_api dynesty-guide bilby-mcmc-guide rng diff --git a/examples/gw_examples/injection_examples/jax_fast_tutorial.py b/examples/gw_examples/injection_examples/jax_fast_tutorial.py new file mode 100644 index 000000000..56b1b4d3a --- /dev/null +++ b/examples/gw_examples/injection_examples/jax_fast_tutorial.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python +""" +Tutorial to demonstrate running parameter estimation on a reduced parameter +space for an injected signal. + +This example estimates the masses using a uniform prior in both component masses +and distance using a uniform in comoving volume prior on luminosity distance +between luminosity distances of 100Mpc and 5Gpc, the cosmology is Planck15. + +We optionally use ripple waveforms and a JIT-compiled likelihood. 
+""" +import os + +# Set OMP_NUM_THREADS to stop lalsimulation taking over my computer +os.environ["OMP_NUM_THREADS"] = "1" + +import bilby +import jax +import jax.numpy as jnp +import numpy as np +from bilby.compat.jax import JittedLikelihood +from ripple.waveforms import IMRPhenomPv2 + +jax.config.update("jax_enable_x64", True) + + +def bilby_to_ripple_spins( + theta_jn, + phi_jl, + tilt_1, + tilt_2, + phi_12, + a_1, + a_2, +): + """ + A simplified spherical to cartesian spin conversion function. + This is not equivalent to the method used in `bilby.gw.conversion` + which comes from `lalsimulation` and is not `JAX` compatible. + """ + iota = theta_jn + spin_1x = a_1 * jnp.sin(tilt_1) * jnp.cos(phi_jl) + spin_1y = a_1 * jnp.sin(tilt_1) * jnp.sin(phi_jl) + spin_1z = a_1 * jnp.cos(tilt_1) + spin_2x = a_2 * jnp.sin(tilt_2) * jnp.cos(phi_jl + phi_12) + spin_2y = a_2 * jnp.sin(tilt_2) * jnp.sin(phi_jl + phi_12) + spin_2z = a_2 * jnp.cos(tilt_2) + return iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z + + +def ripple_bbh( + frequency, + mass_1, + mass_2, + luminosity_distance, + theta_jn, + phase, + a_1, + a_2, + tilt_1, + tilt_2, + phi_12, + phi_jl, + **kwargs, +): + """ + Source function wrapper to ripple's IMRPhenomPv2 waveform generator. + This function cannot be jitted directly as the Bilby waveform generator + relies on inspecting the function signature. + + Parameters + ---------- + frequency: jnp.ndarray + Frequencies at which to compute the waveform. + mass_1: float | jnp.ndarray + Mass of the primary component in solar masses. + mass_2: float | jnp.ndarray + Mass of the secondary component in solar masses. + luminosity_distance: float | jnp.ndarray + Luminosity distance to the source in Mpc. + theta_jn: float | jnp.ndarray + Angle between total angular momentum and line of sight in radians. + phase: float | jnp.ndarray + Phase at coalescence in radians. + a_1: float | jnp.ndarray + Dimensionless spin magnitude of the primary component. 
+ a_2: float | jnp.ndarray + Dimensionless spin magnitude of the secondary component. + tilt_1: float | jnp.ndarray + Tilt angle of the primary component spin in radians. + tilt_2: float | jnp.ndarray + Tilt angle of the secondary component spin in radians. + phi_12: float | jnp.ndarray + Azimuthal angle between the two spin vectors in radians. + phi_jl: float | jnp.ndarray + Azimuthal angle of the total angular momentum vector in radians. + **kwargs + Additional keyword arguments. Must include 'minimum_frequency'. + + Returns + ------- + dict + Dictionary containing the plus and cross polarizations of the waveform. + """ + iota, *cartesian_spins = bilby_to_ripple_spins( + theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2 + ) + frequencies = jnp.maximum(frequency, kwargs["minimum_frequency"]) + theta = jnp.array( + [ + mass_1, + mass_2, + *cartesian_spins, + luminosity_distance, + jnp.array(0.0), + phase, + iota, + ] + ) + wf_func = jax.jit(IMRPhenomPv2.gen_IMRPhenomPv2) + hp, hc = wf_func(frequencies, theta, jnp.array(20.0)) + return dict(plus=hp, cross=hc) + + +def main(): + # Set the duration and sampling frequency of the data segment that we're + # going to inject the signal into + duration = 64.0 + sampling_frequency = 2048.0 + minimum_frequency = 20.0 + duration = jnp.array(duration) + sampling_frequency = jnp.array(sampling_frequency) + minimum_frequency = jnp.array(minimum_frequency) + + # Specify the output directory and the name of the simulation. + outdir = "outdir" + label = "jax_fast_tutorial" + + # Set up a random seed for result reproducibility. This is optional! 
+ bilby.core.utils.random.seed(88170235) + + priors = bilby.gw.prior.BBHPriorDict() + injection_parameters = priors.sample() + injection_parameters["geocent_time"] = 1000000000.0 + injection_parameters["luminosity_distance"] = 400.0 + del priors["ra"], priors["dec"] + priors["zenith"] = bilby.core.prior.Cosine() + priors["azimuth"] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi) + priors["L1_time"] = bilby.core.prior.Uniform( + injection_parameters["geocent_time"] - 0.1, + injection_parameters["geocent_time"] + 0.1, + ) + + # Fixed arguments passed into the source model + waveform_arguments = dict( + waveform_approximant="IMRPhenomPv2", + reference_frequency=50.0, + minimum_frequency=minimum_frequency, + ) + + # Create the waveform_generator using a LAL BinaryBlackHole source function + waveform_generator = bilby.gw.WaveformGenerator( + duration=duration, + sampling_frequency=sampling_frequency, + frequency_domain_source_model=ripple_bbh, + waveform_arguments=waveform_arguments, + use_cache=False, + ) + + # Set up interferometers. In this case we'll use two interferometers + # (LIGO-Hanford (H1), LIGO-Livingston (L1). 
These default to their design + # sensitivity + ifos = bilby.gw.detector.InterferometerList(["H1", "L1"]) + ifos.set_strain_data_from_power_spectral_densities( + sampling_frequency=sampling_frequency, + duration=duration, + start_time=injection_parameters["geocent_time"] - duration + 2, + ) + ifos.inject_signal( + waveform_generator=waveform_generator, + parameters=injection_parameters, + raise_error=False, + ) + ifos.set_array_backend(jnp) + + # Initialise the likelihood by passing in the interferometer data (ifos) and + # the waveform generator + likelihood = bilby.gw.likelihood.GravitationalWaveTransient( + interferometers=ifos, + waveform_generator=waveform_generator, + priors=priors, + phase_marginalization=True, + distance_marginalization=True, + reference_frame=ifos, + time_reference="L1", + ) + # Do an initial likelihood evaluation to trigger any internal setup + likelihood.log_likelihood_ratio(priors.sample()) + # Wrap the likelihood with the JittedLikelihood to JIT compile the likelihood + # evaluation + likelihood = JittedLikelihood(likelihood) + # Evaluate the likelihood once to trigger the JIT compilation, this will take + # a few seconds as compiling the waveform takes some time + likelihood.log_likelihood_ratio(priors.sample()) + + # use the log_compiles context so we can make sure there aren't recompilations + # inside the sampling loop + with jax.log_compiles(): + result = bilby.run_sampler( + likelihood=likelihood, + priors=priors, + sampler="dynesty", + nlive=100, + sample="acceptance-walk", + naccept=5, + injection_parameters=injection_parameters, + outdir=outdir, + label=label, + npool=None, + save="hdf5", + rseed=np.random.randint(0, 100000), + ) + + # Make a corner plot. 
+ result.plot_corner() + + +if __name__ == "__main__": + main() diff --git a/jax_requirements.txt b/jax_requirements.txt new file mode 100644 index 000000000..b325586a3 --- /dev/null +++ b/jax_requirements.txt @@ -0,0 +1,2 @@ +interpax +jax \ No newline at end of file diff --git a/optional_requirements.txt b/optional_requirements.txt index c10d7908b..f0f2205f6 100644 --- a/optional_requirements.txt +++ b/optional_requirements.txt @@ -1,5 +1,6 @@ celerite george +parameterized plotly pytest-requires pytest-rerunfailures diff --git a/pyproject.toml b/pyproject.toml index 145d905d1..b2ccdb444 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,11 +114,13 @@ addopts = [ packages = [ "bilby", "bilby.bilby_mcmc", + "bilby.compat", "bilby.core", "bilby.core.prior", "bilby.core.sampler", "bilby.core.utils", "bilby.gw", + "bilby.gw.compat", "bilby.gw.detector", "bilby.gw.eos", "bilby.gw.likelihood", @@ -133,11 +135,13 @@ dependencies = {file = ["requirements.txt"]} [tool.setuptools.dynamic.optional-dependencies] all = {file = [ "gw_requirements.txt", + "jax_requirements.txt", "mcmc_requirements.txt", "sampler_requirements.txt", "optional_requirements.txt" ]} gw = {file = ["gw_requirements.txt"]} +jax = {file = ["jax_requirements.txt"]} mcmc = {file = ["mcmc_requirements.txt"]} [tool.setuptools.package-data] diff --git a/requirements.txt b/requirements.txt index b045db212..f1a91484d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ -bilby.cython>=0.3.0 +# see https://github.com/data-apis/array-api-compat/pull/341 +array_api_compat>=1.13 +array_api_extra dynesty>=2.0.1 emcee corner @@ -11,4 +13,4 @@ dill tqdm h5py attrs -importlib-metadata>=3.6; python_version < '3.10' +plum-dispatch diff --git a/test/conftest.py b/test/conftest.py index d08c38604..83e7a89a6 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,3 +1,6 @@ +import importlib + +import array_api_compat as aac import pytest @@ -5,15 +8,69 @@ def pytest_addoption(parser): 
parser.addoption( "--skip-roqs", action="store_true", default=False, help="Skip all tests that require ROQs" ) + parser.addoption( + "--array-backend", + default=None, + help="Which array to use for testing", + ) def pytest_configure(config): config.addinivalue_line("markers", "requires_roqs: mark a test that requires ROQs") + config.addinivalue_line("markers", "array_backend: mark that a test uses all array backends") def pytest_collection_modifyitems(config, items): if config.getoption("--skip-roqs"): skip_roqs = pytest.mark.skip(reason="Skipping tests that require ROQs") - for item in items: - if "requires_roqs" in item.keywords: - item.add_marker(skip_roqs) + else: + skip_roqs = None + if config.getoption("--array-backend") is not None: + array_only = pytest.mark.skip(reason="Only running backend dependent tests") + else: + array_only = None + for item in items: + if "requires_roqs" in item.keywords and config.getoption("--skip-roqs"): + item.add_marker(skip_roqs) + elif "array_backend" not in item.keywords and array_only is not None: + item.add_marker(array_only) + + +def _xp(request): + backend = request.config.getoption("--array-backend") + match backend: + case None | "numpy": + import numpy as xp + case "jax" | "jax.numpy": + import jax + + jax.config.update("jax_enable_x64", True) + xp = jax.numpy + case "torch": + import torch + # torch starts a lot of threads, so disable this on the first import + # to avoid segfaults + try: + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + torch.set_default_dtype(torch.float64) + except RuntimeError: + pass + xp = torch + case _: + try: + + xp = importlib.import_module(backend) + except ImportError: + raise ValueError(f"Unknown backend for testing: {backend}") + return aac.get_namespace(xp.ones(1)) + + +@pytest.fixture +def xp(request): + return _xp(request) + + +@pytest.fixture(scope="class") +def xp_class(request): + request.cls.xp = _xp(request) diff --git a/test/core/grid_test.py 
b/test/core/grid_test.py index f14a95134..781077f34 100644 --- a/test/core/grid_test.py +++ b/test/core/grid_test.py @@ -1,30 +1,33 @@ import unittest -import numpy as np import shutil import os -from scipy.stats import multivariate_normal + +import array_api_compat as aac +import numpy as np +import pytest import bilby -# set 2D multivariate Gaussian likelihood class MultiGaussian(bilby.Likelihood): - def __init__(self, mean, cov): + # set 2D multivariate Gaussian likelihood + def __init__(self, mean, cov, *, xp=np): super(MultiGaussian, self).__init__() - self.cov = np.array(cov) - self.mean = np.array(mean) - self.sigma = np.sqrt(np.diag(self.cov)) - self.pdf = multivariate_normal(mean=self.mean, cov=self.cov) + self.xp = xp + self.cov = xp.asarray(cov) + self.mean = xp.asarray(mean) + self.sigma = xp.sqrt(xp.diag(self.cov)) @property def dim(self): return len(self.cov[0]) def log_likelihood(self, parameters): - x = np.array([parameters["x{0}".format(i)] for i in range(self.dim)]) - return self.pdf.logpdf(x) + return -parameters["x0"]**2 / 2 - parameters["x1"]**2 / 2 - np.log(2 * np.pi) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGrid(unittest.TestCase): def setUp(self): bilby.core.utils.random.seed(7) @@ -33,7 +36,7 @@ def setUp(self): self.mus = [0.0, 0.0] self.cov = [[1.0, 0.0], [0.0, 1.0]] dim = len(self.mus) - self.likelihood = MultiGaussian(self.mus, self.cov) + self.likelihood = MultiGaussian(self.mus, self.cov, xp=self.xp) # set priors out to +/- 5 sigma self.priors = bilby.core.prior.PriorDict() @@ -61,6 +64,7 @@ def setUp(self): grid_size=self.grid_size, likelihood=self.likelihood, save=True, + xp=self.xp, ) self.grid = grid @@ -140,7 +144,9 @@ def test_max_marginalized_likelihood(self): self.assertEqual(1.0, self.grid.marginalize_likelihood(self.grid.parameter_names[1]).max()) def test_ln_evidence(self): - self.assertAlmostEqual(self.expected_ln_evidence, self.grid.ln_evidence, places=5) + ln_z = self.grid.ln_evidence + 
self.assertEqual(aac.get_namespace(ln_z), self.xp) + self.assertAlmostEqual(self.expected_ln_evidence, float(ln_z), places=5) def test_fail_grid_size(self): with self.assertRaises(TypeError): @@ -151,6 +157,7 @@ def test_fail_grid_size(self): grid_size=2.3, likelihood=self.likelihood, save=True, + xp=self.xp, ) def test_mesh_grid(self): @@ -165,7 +172,8 @@ def test_grid_integer_points(self): outdir="outdir", priors=self.priors, grid_size=n_points, - likelihood=self.likelihood + likelihood=self.likelihood, + xp=self.xp, ) self.assertTupleEqual(tuple(n_points), grid.mesh_grid[0].shape) @@ -179,7 +187,8 @@ def test_grid_dict_points(self): outdir="outdir", priors=self.priors, grid_size=n_points, - likelihood=self.likelihood + likelihood=self.likelihood, + xp=self.xp, ) self.assertTupleEqual((n_points["x0"], n_points["x1"]), grid.mesh_grid[0].shape) self.assertEqual(grid.mesh_grid[0][0, 0], self.priors[self.grid.parameter_names[0]].minimum) @@ -196,6 +205,7 @@ def test_grid_from_array(self): priors=self.priors, grid_size=n_points, likelihood=self.likelihood, + xp=self.xp, ) self.assertTupleEqual((len(x0s), len(x1s)), grid.mesh_grid[0].shape) @@ -208,7 +218,7 @@ def test_grid_from_array(self): def test_save_and_load_from_filename(self): filename = os.path.join("outdir", "test_output.json") self.grid.save_to_file(filename=filename) - new_grid = bilby.core.grid.Grid.read(filename=filename) + new_grid = bilby.core.grid.Grid.read(filename=filename, xp=self.xp) self.assertListEqual(new_grid.parameter_names, self.grid.parameter_names) self.assertEqual(new_grid.n_dims, self.grid.n_dims) @@ -221,7 +231,7 @@ def test_save_and_load_from_filename(self): def test_save_and_load_from_outdir_label(self): self.grid.save_to_file(overwrite=True, outdir="outdir") - new_grid = bilby.core.grid.Grid.read(outdir="outdir", label="label") + new_grid = bilby.core.grid.Grid.read(outdir="outdir", label="label", xp=self.xp) self.assertListEqual(self.grid.parameter_names, new_grid.parameter_names) 
self.assertEqual(self.grid.n_dims, new_grid.n_dims) @@ -238,7 +248,7 @@ def test_save_and_load_from_outdir_label(self): def test_save_and_load_gzip(self): filename = os.path.join("outdir", "test_output.json.gz") self.grid.save_to_file(filename=filename) - new_grid = bilby.core.grid.Grid.read(filename=filename) + new_grid = bilby.core.grid.Grid.read(filename=filename, xp=self.xp) self.assertListEqual(self.grid.parameter_names, new_grid.parameter_names) self.assertEqual(self.grid.n_dims, new_grid.n_dims) diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index 3c4c71c26..d64e9a12c 100644 --- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -1,7 +1,10 @@ import unittest from unittest import mock +import array_api_compat as aac import numpy as np +import pytest +import array_api_extra as xpx import bilby.core.likelihood from bilby.core.likelihood import ( @@ -51,10 +54,12 @@ def test_meta_data(self): self.assertEqual(self.likelihood.meta_data, meta_data) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestAnalytical1DLikelihood(unittest.TestCase): def setUp(self): - self.x = np.arange(start=0, stop=100, step=1) - self.y = np.arange(start=0, stop=100, step=1) + self.x = self.xp.arange(0, 100, step=1) + self.y = self.xp.arange(0, 100, step=1) def test_func(x, parameter1, parameter2): return parameter1 * x + parameter2 @@ -78,7 +83,7 @@ def test_init_x(self): self.assertTrue(np.array_equal(self.x, self.analytical_1d_likelihood.x)) def test_set_x_to_array(self): - new_x = np.arange(start=0, stop=50, step=2) + new_x = self.xp.arange(0, 50, step=2) self.analytical_1d_likelihood.x = new_x self.assertTrue(np.array_equal(new_x, self.analytical_1d_likelihood.x)) @@ -98,7 +103,7 @@ def test_init_y(self): self.assertTrue(np.array_equal(self.y, self.analytical_1d_likelihood.y)) def test_set_y_to_array(self): - new_y = np.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(0, 50, step=2) 
self.analytical_1d_likelihood.y = new_y self.assertTrue(np.array_equal(new_y, self.analytical_1d_likelihood.y)) @@ -154,17 +159,20 @@ def test_repr(self): self.assertEqual(expected, repr(self.analytical_1d_likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGaussianLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.sigma = 0.1 - self.x = np.linspace(0, 1, self.N) - self.y = 2 * self.x + 1 + np.random.normal(0, self.sigma, self.N) + self.x = self.xp.linspace(0, 1, self.N) + self.y = 2 * self.x + 1 + self.xp.asarray(np.random.normal(0, self.sigma, self.N)) def test_function(x, m, c): return m * x + c self.function = test_function + self.parameters = dict(m=self.xp.asarray(2.0), c=self.xp.asarray(0.0)) def tearDown(self): del self.N @@ -175,24 +183,21 @@ def tearDown(self): def test_known_sigma(self): likelihood = GaussianLikelihood(self.x, self.y, self.function, self.sigma) - parameters = dict(m=2, c=0) - likelihood.log_likelihood(parameters) + likelihood.log_likelihood(self.parameters) self.assertEqual(likelihood.sigma, self.sigma) def test_known_array_sigma(self): sigma_array = np.ones(self.N) * self.sigma likelihood = GaussianLikelihood(self.x, self.y, self.function, sigma_array) - parameters = dict(m=2, c=0) - likelihood.log_likelihood(parameters) + likelihood.log_likelihood(self.parameters) self.assertTrue(type(likelihood.sigma) == type(sigma_array)) # noqa: E721 self.assertTrue(all(likelihood.sigma == sigma_array)) def test_set_sigma_None(self): likelihood = GaussianLikelihood(self.x, self.y, self.function, sigma=None) - parameters = dict(m=2, c=0) self.assertTrue(likelihood.sigma is None) with self.assertRaises(TypeError): - likelihood.log_likelihood(parameters) + likelihood.log_likelihood(self.parameters) def test_sigma_float(self): likelihood = GaussianLikelihood(self.x, self.y, self.function, sigma=None) @@ -211,19 +216,27 @@ def test_repr(self): ) self.assertEqual(expected, repr(likelihood)) + def 
test_return_class(self): + likelihood = GaussianLikelihood(self.x, self.y, self.function, self.sigma) + logl = likelihood.log_likelihood(self.parameters) + self.assertEqual(aac.get_namespace(logl), self.xp) + +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestStudentTLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.nu = self.N - 2 self.sigma = 1 - self.x = np.linspace(0, 1, self.N) - self.y = 2 * self.x + 1 + np.random.normal(0, self.sigma, self.N) + self.x = self.xp.linspace(0, 1, self.N) + self.y = 2 * self.x + 1 + self.xp.asarray(np.random.normal(0, self.sigma, self.N)) def test_function(x, m, c): return m * x + c self.function = test_function + self.parameters = dict(m=self.xp.asarray(2.0), c=self.xp.asarray(0.0)) def tearDown(self): del self.N @@ -236,8 +249,7 @@ def test_known_sigma(self): likelihood = StudentTLikelihood( self.x, self.y, self.function, self.nu, self.sigma ) - parameters = dict(m=2, c=0) - likelihood.log_likelihood(parameters) + likelihood.log_likelihood(self.parameters) self.assertEqual(likelihood.sigma, self.sigma) def test_set_nu_none(self): @@ -246,21 +258,23 @@ def test_set_nu_none(self): def test_log_likelihood_nu_none(self): likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=None) - parameters = dict(m=2, c=0) with self.assertRaises(TypeError): - likelihood.log_likelihood(parameters) + likelihood.log_likelihood(self.parameters) def test_log_likelihood_nu_zero(self): likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=0) - parameters = dict(m=2, c=0) with self.assertRaises(ValueError): - likelihood.log_likelihood(parameters) + likelihood.log_likelihood(self.parameters) def test_log_likelihood_nu_negative(self): likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=-1) - parameters = dict(m=2, c=0) with self.assertRaises(ValueError): - likelihood.log_likelihood(parameters) + likelihood.log_likelihood(self.parameters) + + def 
test_setting_nu_positive_does_not_change_class_attribute(self): + likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=None) + likelihood.nu = 98 + self.assertEqual(likelihood.nu, 98) def test_lam(self): likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=0, sigma=0.5) @@ -279,25 +293,28 @@ def test_repr(self): self.assertEqual(expected, repr(likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestPoissonLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.mu = 5 - self.x = np.linspace(0, 1, self.N) - self.y = np.random.poisson(self.mu, self.N) - self.yfloat = np.copy(self.y) * 1.0 - self.yneg = np.copy(self.y) - self.yneg[0] = -1 + self.x = self.xp.linspace(0, 1, self.N) + self.y = self.xp.asarray(np.random.poisson(self.mu, self.N)) + self.yfloat = self.y * 1.0 + self.yneg = self.y * 1.0 + self.yneg = xpx.at(self.yneg, 0).set(-1) def test_function(x, c): return c def test_function_array(x, c): - return np.ones(len(x)) * c + return self.xp.ones(len(x)) * c self.function = test_function self.function_array = test_function_array self.poisson_likelihood = PoissonLikelihood(self.x, self.y, self.function) + self.bad_parameters = dict(c=self.xp.asarray(-2.0)) def tearDown(self): del self.N @@ -311,6 +328,8 @@ def tearDown(self): del self.poisson_likelihood def test_init_y_non_integer(self): + if aac.is_torch_namespace(self.xp): + pytest.skip("Torch tensor dtype does not have a 'kind' attribute") with self.assertRaises(ValueError): PoissonLikelihood(self.x, self.yfloat, self.function) @@ -319,23 +338,23 @@ def test_init__y_negative(self): PoissonLikelihood(self.x, self.yneg, self.function) def test_neg_rate(self): - parameters = dict(c=-2) with self.assertRaises(ValueError): - self.poisson_likelihood.log_likelihood(parameters) + self.poisson_likelihood.log_likelihood(self.bad_parameters) def test_neg_rate_array(self): likelihood = PoissonLikelihood(self.x, self.y, self.function_array) - parameters = 
dict(c=-2) with self.assertRaises(ValueError): - likelihood.log_likelihood(parameters) + likelihood.log_likelihood(self.bad_parameters) def test_init_y(self): - self.assertTrue(np.array_equal(self.y, self.poisson_likelihood.y)) + self.assertEqual(aac.get_namespace(self.y), aac.get_namespace(self.poisson_likelihood.y)) + np.testing.assert_array_equal(np.asarray(self.y), np.asarray(self.poisson_likelihood.y)) def test_set_y_to_array(self): - new_y = np.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(0, 50, step=2) self.poisson_likelihood.y = new_y - self.assertTrue(np.array_equal(new_y, self.poisson_likelihood.y)) + self.assertEqual(aac.get_namespace(new_y), aac.get_namespace(self.poisson_likelihood.y)) + np.testing.assert_array_equal(np.asarray(new_y), np.asarray(self.poisson_likelihood.y)) def test_set_y_to_positive_int(self): new_y = 5 @@ -360,25 +379,25 @@ def test_log_likelihood_wrong_func_return_type(self): def test_log_likelihood_negative_func_return_element(self): poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([3, 6, -2]) + x=self.x, y=self.y, func=lambda x: self.xp.asarray([3, 6, -2]) ) with self.assertRaises(ValueError): poisson_likelihood.log_likelihood(dict()) def test_log_likelihood_zero_func_return_element(self): poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([3, 6, 0]) + x=self.x, y=self.y, func=lambda x: self.xp.asarray([3, 6, 0]) ) self.assertEqual(-np.inf, poisson_likelihood.log_likelihood(dict())) def test_log_likelihood_dummy(self): """ Merely tests if it goes into the right if else bracket """ poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.linspace(1, 100, self.N) + x=self.x, y=self.y, func=lambda x: self.xp.linspace(1, 100, self.N) ) - with mock.patch("numpy.sum") as m: + with mock.patch(f"{self.xp.__name__}.sum") as m: m.return_value = 1 - self.assertEqual(1, poisson_likelihood.log_likelihood(dict())) + self.assertEqual(1, 
poisson_likelihood.log_likelihood(dict(c=5))) def test_repr(self): likelihood = PoissonLikelihood(self.x, self.y, self.function) @@ -388,26 +407,29 @@ def test_repr(self): self.assertEqual(expected, repr(likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestExponentialLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.mu = 5 - self.x = np.linspace(0, 1, self.N) - self.y = np.random.exponential(self.mu, self.N) - self.yneg = np.copy(self.y) - self.yneg[0] = -1.0 + self.x = self.xp.linspace(0, 1, self.N) + self.y = self.xp.asarray(np.random.exponential(self.mu, self.N)) + self.yneg = self.y * 1.0 + self.yneg = xpx.at(self.yneg, 0).set(-1.0) def test_function(x, c): return c def test_function_array(x, c): - return c * np.ones(len(x)) + return c * self.xp.ones(len(x)) self.function = test_function self.function_array = test_function_array self.exponential_likelihood = ExponentialLikelihood( x=self.x, y=self.y, func=self.function ) + self.bad_parameters = dict(c=self.xp.asarray(-1.0)) def tearDown(self): del self.N @@ -424,19 +446,17 @@ def test_negative_data(self): def test_negative_function(self): likelihood = ExponentialLikelihood(self.x, self.y, self.function) - parameters = dict(c=-1) - self.assertEqual(likelihood.log_likelihood(parameters), -np.inf) + self.assertEqual(likelihood.log_likelihood(self.bad_parameters), -np.inf) def test_negative_array_function(self): likelihood = ExponentialLikelihood(self.x, self.y, self.function_array) - parameters = dict(c=-1) - self.assertEqual(likelihood.log_likelihood(parameters), -np.inf) + self.assertEqual(likelihood.log_likelihood(self.bad_parameters), -np.inf) def test_init_y(self): self.assertTrue(np.array_equal(self.y, self.exponential_likelihood.y)) def test_set_y_to_array(self): - new_y = np.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(0, 50, step=2) self.exponential_likelihood.y = new_y self.assertTrue(np.array_equal(new_y, self.exponential_likelihood.y)) @@ 
-461,14 +481,14 @@ def test_set_y_to_negative_float(self): def test_set_y_to_nd_array_with_negative_element(self): with self.assertRaises(ValueError): - self.exponential_likelihood.y = np.array([4.3, -1.2, 4]) + self.exponential_likelihood.y = self.xp.asarray([4.3, -1.2, 4]) def test_log_likelihood_default(self): """ Merely tests that it ends up at the right place in the code """ exponential_likelihood = ExponentialLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([4.2]) + x=self.x, y=self.y, func=lambda x: self.xp.asarray([4.2]) ) - with mock.patch("numpy.sum") as m: + with mock.patch(f"{self.xp.__name__}.sum") as m: m.return_value = 3 self.assertEqual(-3, exponential_likelihood.log_likelihood(dict())) @@ -479,14 +499,21 @@ def test_repr(self): self.assertEqual(expected, repr(self.exponential_likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestAnalyticalMultidimensionalCovariantGaussian(unittest.TestCase): def setUp(self): self.cov = [[1, 0, 0], [0, 4, 0], [0, 0, 9]] self.sigma = [1, 2, 3] self.mean = [10, 11, 12] + if self.xp != np: + self.cov = self.xp.asarray(self.cov, dtype=float) + self.sigma = self.xp.asarray(self.sigma, dtype=float) + self.mean = self.xp.asarray(self.mean, dtype=float) self.likelihood = AnalyticalMultidimensionalCovariantGaussian( mean=self.mean, cov=self.cov ) + self.parameters = {f"x{ii}": 0 for ii in range(len(self.sigma))} def tearDown(self): del self.cov @@ -507,19 +534,34 @@ def test_dim(self): self.assertEqual(3, self.likelihood.dim) def test_log_likelihood(self): - likelihood = AnalyticalMultidimensionalCovariantGaussian(mean=[0], cov=[1]) - self.assertEqual(-np.log(2 * np.pi) / 2, likelihood.log_likelihood(dict(x0=0))) + likelihood = AnalyticalMultidimensionalCovariantGaussian( + mean=self.xp.asarray([0.0]), cov=self.xp.asarray([1.0]) + ) + logl = likelihood.log_likelihood(dict(x0=self.xp.asarray(0.0))) + self.assertEqual( + -np.log(2 * np.pi) / 2, + 
likelihood.log_likelihood(dict(x0=self.xp.asarray(0.0))), + ) + self.assertEqual(aac.get_namespace(logl), self.xp) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestAnalyticalMultidimensionalBimodalCovariantGaussian(unittest.TestCase): def setUp(self): self.cov = [[1, 0, 0], [0, 4, 0], [0, 0, 9]] self.sigma = [1, 2, 3] self.mean_1 = [10, 11, 12] self.mean_2 = [20, 21, 22] + if self.xp != np: + self.cov = self.xp.asarray(self.cov, dtype=float) + self.sigma = self.xp.asarray(self.sigma, dtype=float) + self.mean_1 = self.xp.asarray(self.mean_1, dtype=float) + self.mean_2 = self.xp.asarray(self.mean_2, dtype=float) self.likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian( mean_1=self.mean_1, mean_2=self.mean_2, cov=self.cov ) + self.parameters = {f"x{ii}": 0 for ii in range(len(self.sigma))} def tearDown(self): del self.cov @@ -547,7 +589,10 @@ def test_log_likelihood(self): likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian( mean_1=[0], mean_2=[0], cov=[1] ) - self.assertEqual(-np.log(2 * np.pi) / 2, likelihood.log_likelihood(dict(x0=0))) + self.assertEqual( + -np.log(2 * np.pi) / 2, + likelihood.log_likelihood(dict(x0=self.xp.asarray(0.0))), + ) class TestJointLikelihood(unittest.TestCase): diff --git a/test/core/prior/analytical_test.py b/test/core/prior/analytical_test.py index 12892aca1..09942ba07 100644 --- a/test/core/prior/analytical_test.py +++ b/test/core/prior/analytical_test.py @@ -1,16 +1,24 @@ import unittest -import numpy as np +import array_api_compat as aac import bilby +import numpy as np +import pytest +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestDiscreteValuesPrior(unittest.TestCase): + def setUp(self): + if aac.is_torch_namespace(self.xp): + pytest.skip("DiscreteValues prior is unstable for torch backend") + def test_single_sample(self): values = [1.1, 1.2, 1.3] discrete_value_prior = bilby.core.prior.DiscreteValues(values) in_prior = True for _ in range(1000): - s = 
discrete_value_prior.sample() + s = discrete_value_prior.sample(xp=self.xp) if s not in values: in_prior = False self.assertTrue(in_prior) @@ -20,7 +28,7 @@ def test_array_sample(self): nvalues = 4 discrete_value_prior = bilby.core.prior.DiscreteValues(values) N = 100000 - s = discrete_value_prior.sample(N) + s = discrete_value_prior.sample(N, xp=self.xp) zeros = np.sum(s == 1.0) ones = np.sum(s == 1.1) twos = np.sum(s == 1.2) @@ -35,60 +43,64 @@ def test_single_probability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.prob(1.1), 1 / N) - self.assertEqual(discrete_value_prior.prob(2.2), 1 / N) - self.assertEqual(discrete_value_prior.prob(300.0), 1 / N) - self.assertEqual(discrete_value_prior.prob(0.5), 0) - self.assertEqual(discrete_value_prior.prob(200), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(1.1)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(2.2)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(300.0)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(0.5)), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(200)), 0) def test_single_probability_unsorted(self): N = 3 values = [1.1, 300, 2.2] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.prob(1.1), 1 / N) - self.assertEqual(discrete_value_prior.prob(2.2), 1 / N) - self.assertEqual(discrete_value_prior.prob(300.0), 1 / N) - self.assertEqual(discrete_value_prior.prob(0.5), 0) - self.assertEqual(discrete_value_prior.prob(200), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(1.1)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(2.2)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(300.0)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(0.5)), 0) + 
self.assertEqual(discrete_value_prior.prob(self.xp.asarray(200)), 0) + self.assertEqual( + aac.get_namespace(discrete_value_prior.prob(self.xp.asarray(0.5))), + self.xp, + ) def test_array_probability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertTrue( - np.all( - discrete_value_prior.prob([1.1, 2.2, 2.2, 300.0, 200.0]) - == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) - ) - ) + probs = discrete_value_prior.prob(self.xp.asarray([1.1, 2.2, 2.2, 300.0, 200.0])) + self.assertEqual(aac.get_namespace(probs), self.xp) + np.testing.assert_array_equal(np.asarray(probs), np.array([1 / N] * 4 + [0])) def test_single_lnprobability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.ln_prob(1.1), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(2.2), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(300), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(150), -np.inf) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(1.1)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(2.2)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(300)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(150)), -np.inf) + self.assertEqual( + aac.get_namespace(discrete_value_prior.ln_prob(self.xp.asarray(0.5))), + self.xp, + ) def test_array_lnprobability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertTrue( - np.all( - discrete_value_prior.ln_prob([1.1, 2.2, 2.2, 300, 150]) - == np.array([-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]) - ) - ) + ln_probs = discrete_value_prior.ln_prob(self.xp.asarray([1.1, 2.2, 2.2, 300, 150])) + self.assertEqual(aac.get_namespace(ln_probs), self.xp) + np.testing.assert_array_equal(np.asarray(ln_probs), 
np.array([-np.log(N)] * 4 + [-np.inf])) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestCategoricalPrior(unittest.TestCase): def test_single_sample(self): categorical_prior = bilby.core.prior.Categorical(3) in_prior = True for _ in range(1000): - s = categorical_prior.sample() + s = categorical_prior.sample(xp=self.xp) if s not in [0, 1, 2]: in_prior = False self.assertTrue(in_prior) @@ -97,7 +109,9 @@ def test_array_sample(self): ncat = 4 categorical_prior = bilby.core.prior.Categorical(ncat) N = 100000 - s = categorical_prior.sample(N) + s = categorical_prior.sample(N, xp=self.xp) + self.assertEqual(aac.get_namespace(s), self.xp) + s = np.asarray(s) zeros = np.sum(s == 0) ones = np.sum(s == 1) twos = np.sum(s == 2) @@ -111,37 +125,55 @@ def test_array_sample(self): def test_single_probability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertEqual(categorical_prior.prob(0), 1 / N) - self.assertEqual(categorical_prior.prob(1), 1 / N) - self.assertEqual(categorical_prior.prob(2), 1 / N) - self.assertEqual(categorical_prior.prob(0.5), 0) + self.assertEqual(categorical_prior.prob(self.xp.asarray(0)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.asarray(1)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.asarray(2)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.asarray(0.5)), 0) + self.assertEqual( + aac.get_namespace(categorical_prior.prob(self.xp.asarray(0.5))), + self.xp, + ) def test_array_probability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertTrue(np.all(categorical_prior.prob([0, 1, 1, 2, 3]) == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]))) + probs = categorical_prior.prob(self.xp.asarray([0, 1, 1, 2, 3])) + self.assertEqual(aac.get_namespace(probs), self.xp) + + self.assertTrue(np.all( + np.asarray(probs) == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) + )) def test_single_lnprobability(self): N = 3 categorical_prior = 
bilby.core.prior.Categorical(N) - self.assertEqual(categorical_prior.ln_prob(0), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(1), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(2), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(0.5), -np.inf) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(0)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(1)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(2)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(0.5)), -np.inf) + self.assertEqual( + aac.get_namespace(categorical_prior.ln_prob(self.xp.asarray(0.5))), + self.xp, + ) def test_array_lnprobability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertTrue(np.all(categorical_prior.ln_prob([0, 1, 1, 2, 3]) == np.array( - [-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]))) + ln_prob = categorical_prior.ln_prob(self.xp.asarray([0, 1, 1, 2, 3])) + self.assertEqual(aac.get_namespace(ln_prob), self.xp) + self.assertTrue(np.all( + np.asarray(ln_prob) == np.array([-np.log(N)] * 4 + [-np.inf]) + )) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestWeightedCategoricalPrior(unittest.TestCase): def test_single_sample(self): categorical_prior = bilby.core.prior.WeightedCategorical(3, [1, 2, 3]) in_prior = True for _ in range(1000): - s = categorical_prior.sample() + s = categorical_prior.sample(xp=self.xp) if s not in [0, 1, 2]: in_prior = False self.assertTrue(in_prior) @@ -157,7 +189,9 @@ def test_array_sample(self): weights = np.arange(1, ncat + 1) categorical_prior = bilby.core.prior.WeightedCategorical(ncat, weights=weights) N = 100000 - s = categorical_prior.sample(N) + s = categorical_prior.sample(N, xp=self.xp) + self.assertEqual(aac.get_namespace(s), self.xp) + s = np.asarray(s) cases = 0 for i in categorical_prior.values: case = np.sum(s == i) @@ -170,26 +204,35 @@ def 
test_single_probability(self): N = 3 weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) - for i in categorical_prior.values: + for i in self.xp.asarray(categorical_prior.values): self.assertEqual(categorical_prior.prob(i), weights[i] / np.sum(weights)) - self.assertEqual(categorical_prior.prob(0.5), 0) + prob = categorical_prior.prob(self.xp.asarray(0.5)) + self.assertEqual(prob, 0) + self.assertEqual(aac.get_namespace(prob), self.xp) def test_array_probability(self): N = 3 - test_cases = [0, 1, 1, 2, 3] + test_cases = self.xp.asarray([0, 1, 1, 2, 3]) weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) probs = np.arange(1, N + 2) / np.sum(weights) probs[-1] = 0 - self.assertTrue(np.all(categorical_prior.prob(test_cases) == probs[test_cases])) + new = categorical_prior.prob(test_cases) + self.assertEqual(aac.get_namespace(new), self.xp) + self.assertTrue(np.all(np.asarray(new) == probs[test_cases])) def test_single_lnprobability(self): N = 3 weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) - for i in categorical_prior.values: - self.assertEqual(categorical_prior.ln_prob(i), np.log(weights[i] / np.sum(weights))) - self.assertEqual(categorical_prior.prob(0.5), 0) + for i in self.xp.asarray(categorical_prior.values): + self.assertEqual( + categorical_prior.ln_prob(self.xp.asarray(i)), + np.log(weights[i] / np.sum(weights)), + ) + prob = categorical_prior.prob(self.xp.asarray(0.5)) + self.assertEqual(prob, 0) + self.assertEqual(aac.get_namespace(prob), self.xp) def test_array_lnprobability(self): N = 3 @@ -200,7 +243,9 @@ def test_array_lnprobability(self): ln_probs = np.log(np.arange(1, N + 2) / np.sum(weights)) ln_probs[-1] = -np.inf - self.assertTrue(np.all(categorical_prior.ln_prob(test_cases) == ln_probs[test_cases])) + new = categorical_prior.ln_prob(self.xp.asarray(test_cases)) + 
self.assertEqual(aac.get_namespace(new), self.xp) + self.assertTrue(np.all(np.asarray(new) == ln_probs[test_cases])) def test_cdf(self): """ @@ -213,11 +258,12 @@ def test_cdf(self): categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) sample = categorical_prior.sample(size=10) - original = np.asarray(sample) - new = np.array(categorical_prior.rescale( + original = self.xp.asarray(sample) + new = self.xp.asarray(categorical_prior.rescale( categorical_prior.cdf(sample) )) np.testing.assert_array_equal(original, new) + self.assertEqual(type(new), type(original)) if __name__ == "__main__": diff --git a/test/core/prior/base_test.py b/test/core/prior/base_test.py index c9b788732..469c53ece 100644 --- a/test/core/prior/base_test.py +++ b/test/core/prior/base_test.py @@ -1,7 +1,9 @@ import unittest from unittest.mock import Mock +import array_api_compat as aac import numpy as np +import pytest import bilby @@ -56,7 +58,7 @@ def test_base_prob(self): self.assertTrue(np.isnan(self.prior.prob(5))) def test_base_ln_prob(self): - self.prior.prob = lambda val: val + self.prior.prob = lambda val, *, xp=None: val self.assertEqual(np.log(5), self.prior.ln_prob(5)) def test_is_in_prior(self): @@ -139,6 +141,8 @@ def test_prob_inside(self): self.assertEqual(1, self.prior.prob(0.5)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestConstraintPriorNormalisation(unittest.TestCase): def setUp(self): self.priors = dict( @@ -154,8 +158,10 @@ def conversion_func(parameters): def test_prob_integrate_to_one(self): keys = ["a", "b", "c"] n_samples = 1000000 - samples = self.priors.sample_subset(keys=keys, size=n_samples) + samples = self.priors.sample_subset(keys=keys, size=n_samples, xp=self.xp) prob = self.priors.prob(samples, axis=0) + self.assertEqual(aac.get_namespace(prob), self.xp) + prob = np.asarray(prob) dm1 = self.priors["a"].maximum - self.priors["a"].minimum dm2 = self.priors["b"].maximum - self.priors["b"].minimum prior_volume = 
(dm1 * dm2) @@ -169,5 +175,24 @@ def test_prob_integrate_to_one(self): self.assertAlmostEqual(1, integral, delta=7 * sigma_integral) +class TestPriorSubclassWithoutXpWarning(unittest.TestCase): + def test_custom_subclass_without_xp_issues_warning(self): + """Test that a custom prior subclass without xp parameter in rescale method issues a warning.""" + with pytest.warns( + DeprecationWarning, + match=r"rescale.*CustomPriorWithoutXp.*xp.*keyword argument", + ): + # Define a custom prior subclass that doesn't include xp in rescale method + class CustomPriorWithoutXp(bilby.core.prior.Prior): + def rescale(self, val): + """Custom rescale without xp parameter""" + return val * 2 + + prior = CustomPriorWithoutXp(name="custom_prior") + import jax.numpy as jnp + rescaled = prior.rescale(jnp.array([0.1, 0.2, 3])) + self.assertEqual(aac.get_namespace(rescaled), jnp) + + if __name__ == "__main__": unittest.main() diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 20c0cda93..5850f5ecf 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -3,9 +3,11 @@ import unittest from unittest import mock +import array_api_compat as aac import numpy as np import pandas as pd import pickle +import pytest import bilby @@ -172,6 +174,8 @@ def test_cond_prior_instantiation_no_boundary_prior(self): self.assertIsNone(prior.boundary) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestConditionalPriorDict(unittest.TestCase): def setUp(self): def condition_func_1(reference_parameters, var_0): @@ -208,7 +212,12 @@ def condition_func_3(reference_parameters, var_1, var_2): self.conditional_priors_manually_set_items = ( bilby.core.prior.ConditionalPriorDict() ) - self.test_sample = dict(var_0=0.7, var_1=0.6, var_2=0.5, var_3=0.4) + self.test_sample = dict( + var_0=self.xp.asarray(0.7), + var_1=self.xp.asarray(0.6), + var_2=self.xp.asarray(0.5), + var_3=self.xp.asarray(0.4), + ) self.test_value = 1 
/ np.prod([self.test_sample[f"var_{ii}"] for ii in range(3)]) for key, value in dict( var_0=self.prior_0, @@ -260,12 +269,14 @@ def test_conditional_keys_setting_items(self): ) def test_prob(self): - self.assertEqual(self.test_value, self.conditional_priors.prob(sample=self.test_sample)) + prob = self.conditional_priors.prob(sample=self.test_sample) + self.assertEqual(self.test_value, prob) + self.assertEqual(aac.get_namespace(prob), self.xp) def test_prob_illegal_conditions(self): del self.conditional_priors["var_0"] with self.assertRaises(bilby.core.prior.IllegalConditionsException): - self.conditional_priors.prob(sample=self.test_sample) + self.conditional_priors.prob(sample=self.test_sample, xp=self.xp) def test_ln_prob(self): self.assertEqual(np.log(self.test_value), self.conditional_priors.ln_prob(sample=self.test_sample)) @@ -324,7 +335,7 @@ def test_rescale(self): expected = [self.test_sample["var_0"]] for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) - self.assertListEqual(expected, res) + np.testing.assert_array_equal(expected, res) def test_rescale_with_joint_prior(self): """ @@ -349,19 +360,20 @@ def test_rescale_with_joint_prior(self): ) ) - ref_variables = list(self.test_sample.values()) + [0.4, 0.1] - keys = list(self.test_sample.keys()) + names + ref_variables = list(self.test_sample.values()) + ref_variables = ref_variables[:2] + [0.1] + ref_variables[2:] + [0.4] + keys = list(self.test_sample.keys()) + keys = keys[:2] + ["mvgvar_0"] + keys[2:] + ["mvgvar_1"] res = priordict.rescale(keys=keys, theta=ref_variables) - self.assertIsInstance(res, list) self.assertEqual(np.shape(res), (6,)) - self.assertListEqual([isinstance(r, float) for r in res], 6 * [True]) + self.assertEqual(aac.get_namespace(res), self.xp) # check conditional values are still as expected expected = [self.test_sample["var_0"]] for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) - self.assertListEqual(expected, 
res[0:4]) + np.testing.assert_array_equal(expected, list(res)[:2] + list(res)[3:5]) def test_cdf(self): """ @@ -370,11 +382,11 @@ def test_cdf(self): Note that the format of inputs/outputs is different between the two methods. """ sample = self.conditional_priors.sample() - self.assertEqual( + np.testing.assert_array_equal( self.conditional_priors.rescale( sample.keys(), self.conditional_priors.cdf(sample=sample).values() - ), list(sample.values()) + ), np.array(list(sample.values())) ) def test_rescale_illegal_conditions(self): @@ -446,6 +458,8 @@ def _tp_conditional_uniform(ref_params, period): prior.sample_subset(["tp"], 1000) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestDirichletPrior(unittest.TestCase): def setUp(self): @@ -455,6 +469,10 @@ def tearDown(self): if os.path.isdir("priors"): shutil.rmtree("priors") + def test_samples_correct_type(self): + samples = self.priors.sample(10, xp=self.xp) + self.assertEqual(aac.get_namespace(samples["dirichlet_1"]), self.xp) + def test_samples_sum_to_less_than_one(self): """ Test that the samples sum to less than one as required for the diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index 089611aee..425e1ff49 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -2,7 +2,9 @@ import unittest from unittest.mock import Mock, patch +import array_api_compat as aac import numpy as np +import pytest import bilby @@ -22,6 +24,8 @@ def __init__(self, names, bounds=None): setattr(bilby.core.prior, "FakeJointPriorDist", FakeJointPriorDist) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestPriorDict(unittest.TestCase): def setUp(self): @@ -268,30 +272,40 @@ def test_dict_argument_is_not_string_or_dict(self): def test_sample_subset_correct_size(self): size = 7 samples = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys(), size=size + keys=self.prior_set_from_dict.keys(), size=size, + xp=self.xp, ) 
self.assertEqual(len(self.prior_set_from_dict), len(samples)) for key in samples: self.assertEqual(size, len(samples[key])) + self.assertEqual(aac.get_namespace(samples[key]), self.xp) def test_sample_subset_correct_size_when_non_priors_in_dict(self): self.prior_set_from_dict["asdf"] = "not_a_prior" samples = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys() + keys=self.prior_set_from_dict.keys(), + xp=self.xp, ) self.assertEqual(len(self.prior_set_from_dict) - 1, len(samples)) + for key in samples: + self.assertEqual(aac.get_namespace(samples[key]), self.xp) def test_sample_subset_with_actual_subset(self): size = 3 - samples = self.prior_set_from_dict.sample_subset(keys=["length"], size=size) - expected = dict(length=np.array([42.0, 42.0, 42.0])) + samples = self.prior_set_from_dict.sample_subset( + keys=["length"], size=size, xp=self.xp + ) + expected = dict(length=self.xp.asarray([42.0, 42.0, 42.0])) self.assertTrue(np.array_equal(expected["length"], samples["length"])) + self.assertEqual(aac.get_namespace(samples["length"]), self.xp) def test_sample_subset_constrained_as_array(self): size = 3 keys = ["mass", "speed"] - out = self.prior_set_from_dict.sample_subset_constrained_as_array(keys, size) - self.assertTrue(isinstance(out, np.ndarray)) + out = self.prior_set_from_dict.sample_subset_constrained_as_array( + keys, size, xp=self.xp + ) + self.assertEqual(aac.get_namespace(out), self.xp) self.assertTrue(out.shape == (len(keys), size)) def test_sample_subset_constrained(self): @@ -312,7 +326,7 @@ def conversion_function(parameters): with patch("bilby.core.prior.logger.warning") as mock_warning: samples1 = priors1.sample_subset_constrained( - keys=list(priors1.keys()), size=N + keys=list(priors1.keys()), size=N, xp=self.xp ) self.assertEqual(len(priors1) - 1, len(samples1)) for key in samples1: @@ -325,7 +339,7 @@ def conversion_function(parameters): with patch("bilby.core.prior.logger.warning") as mock_warning: samples2 =
priors2.sample_subset_constrained( - keys=list(priors2.keys()), size=N + keys=list(priors2.keys()), size=N, xp=self.xp ) self.assertEqual(len(priors2), len(samples2)) for key in samples2: @@ -336,27 +350,31 @@ def test_sample(self): size = 7 bilby.core.utils.random.seed(42) samples1 = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys(), size=size + keys=self.prior_set_from_dict.keys(), size=size, xp=self.xp ) bilby.core.utils.random.seed(42) - samples2 = self.prior_set_from_dict.sample(size=size) + samples2 = self.prior_set_from_dict.sample(size=size, xp=self.xp) self.assertEqual(set(samples1.keys()), set(samples2.keys())) for key in samples1: self.assertTrue(np.array_equal(samples1[key], samples2[key])) + self.assertEqual(aac.get_namespace(samples1[key]), self.xp) + self.assertEqual(aac.get_namespace(samples2[key]), self.xp) def test_prob(self): - samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"]) + samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"], xp=self.xp) expected = self.first_prior.prob(samples["mass"]) * self.second_prior.prob( samples["speed"] ) self.assertEqual(expected, self.prior_set_from_dict.prob(samples)) + self.assertEqual(aac.get_namespace(expected), self.xp) def test_ln_prob(self): - samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"]) + samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"], xp=self.xp) expected = self.first_prior.ln_prob( samples["mass"] ) + self.second_prior.ln_prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) + self.assertEqual(aac.get_namespace(expected), self.xp) def test_rescale(self): theta = [0.5, 0.5, 0.5] @@ -380,13 +398,14 @@ def test_cdf(self): Note that the format of inputs/outputs is different between the two methods. 
""" - sample = self.prior_set_from_dict.sample() - original = np.array(list(sample.values())) - new = np.array(self.prior_set_from_dict.rescale( + sample = self.prior_set_from_dict.sample(xp=self.xp) + original = self.xp.asarray(list(sample.values())) + new = self.xp.asarray(self.prior_set_from_dict.rescale( sample.keys(), self.prior_set_from_dict.cdf(sample=sample).values() )) self.assertLess(max(abs(original - new)), 1e-10) + self.assertEqual(aac.get_namespace(new), self.xp) def test_redundancy(self): for key in self.prior_set_from_dict.keys(): diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 17d360d0c..67fbd2422 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -1,14 +1,37 @@ +import array_api_compat as aac import bilby import unittest import numpy as np import os +import pytest import scipy.stats as ss from scipy.integrate import trapezoid +aligned_prior_complex = bilby.gw.prior.AlignedSpin( + a_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0), + z_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0, minimum=-1), + name="test", + unit="unit", + num_interp=1000, +) + +hp_map_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "prior_files/GW150914_testing_skymap.fits", +) +hp_dist = bilby.gw.prior.HealPixMapPriorDist( + hp_map_file, names=["testra", "testdec"] +) +hp_3d_dist = bilby.gw.prior.HealPixMapPriorDist( + hp_map_file, names=["testra", "testdec", "testdistance"], distance=True +) + + +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestPriorClasses(unittest.TestCase): def setUp(self): - # set multivariate Gaussian mvg = bilby.core.prior.MultivariateGaussianDist( names=["testa", "testb"], @@ -22,16 +45,10 @@ def setUp(self): covs=np.array([[2.0, 0.5], [0.5, 2.0]]), weights=1.0, ) - hp_map_file = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "prior_files/GW150914_testing_skymap.fits", - ) - hp_dist = bilby.gw.prior.HealPixMapPriorDist( - 
hp_map_file, names=["testra", "testdec"] - ) - hp_3d_dist = bilby.gw.prior.HealPixMapPriorDist( - hp_map_file, names=["testra", "testdec", "testdistance"], distance=True - ) + + # need to reset this for the repr test to get equality correct + hp_dist.requested_parameters = {"testra": None, "testdec": None} + hp_3d_dist.requested_parameters = {"testra": None, "testdec": None, "testdistance": None} def condition_func(reference_params, test_param): return reference_params.copy() @@ -102,13 +119,7 @@ def condition_func(reference_params, test_param): name="test", unit="unit", minimum=1e-2, maximum=1e2 ), bilby.gw.prior.AlignedSpin(name="test", unit="unit"), - bilby.gw.prior.AlignedSpin( - a_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0), - z_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0, minimum=-1), - name="test", - unit="unit", - num_interp=1000, - ), + aligned_prior_complex, bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit"), bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit"), bilby.core.prior.MultivariateNormal(dist=mvn, name="testa", unit="unit"), @@ -243,6 +254,16 @@ def condition_func(reference_params, test_param): dist=hp_3d_dist, name="testdistance", unit="unit" ), ] + if aac.is_torch_namespace(self.xp): + self.priors = [ + p for p in self.priors + if not isinstance(p, bilby.core.prior.Interped) + ] + elif aac.is_jax_namespace(self.xp): + self.priors = [ + p for p in self.priors + if not isinstance(p, bilby.core.prior.StudentT) + ] def tearDown(self): del self.priors @@ -257,26 +278,35 @@ def test_minimum_rescaling(self): # the edge of the prior is extremely suppressed for these priors # and so the rescale function doesn't quite return the lower bound continue + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue elif bilby.core.prior.JointPrior in prior.__class__.__mro__: - minimum_sample = prior.rescale(0) + 
minimum_sample = prior.rescale(self.xp.asarray(0)) if prior.dist.filled_rescale(): - self.assertAlmostEqual(minimum_sample[0], prior.minimum) - self.assertAlmostEqual(minimum_sample[1], prior.minimum) + self.assertAlmostEqual(np.asarray(minimum_sample[0]), prior.minimum) + self.assertAlmostEqual(np.asarray(minimum_sample[1]), prior.minimum) else: - minimum_sample = prior.rescale(0) - self.assertAlmostEqual(minimum_sample, prior.minimum) + minimum_sample = prior.rescale(self.xp.asarray(0)) + self.assertAlmostEqual(np.asarray(minimum_sample), prior.minimum) def test_maximum_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue if bilby.core.prior.JointPrior in prior.__class__.__mro__: - maximum_sample = prior.rescale(0) + maximum_sample = prior.rescale(self.xp.asarray(0)) if prior.dist.filled_rescale(): - self.assertAlmostEqual(maximum_sample[0], prior.maximum) - self.assertAlmostEqual(maximum_sample[1], prior.maximum) + self.assertAlmostEqual(np.asarray(maximum_sample[0]), prior.maximum) + self.assertAlmostEqual(np.asarray(maximum_sample[1]), prior.maximum) + elif isinstance(prior, bilby.gw.prior.AlignedSpin): + maximum_sample = prior.rescale(self.xp.asarray(1)) + self.assertGreater(np.asarray(maximum_sample), 0.997) else: - maximum_sample = prior.rescale(1) - self.assertAlmostEqual(maximum_sample, prior.maximum) + maximum_sample = prior.rescale(self.xp.asarray(1)) + self.assertAlmostEqual(np.asarray(maximum_sample), prior.maximum) def test_many_sample_rescaling(self): """Test the the rescaling works as expected.""" @@ -284,20 +314,25 @@ def test_many_sample_rescaling(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - many_samples = prior.rescale(np.random.uniform(0, 1, 1000)) + if isinstance(prior, 
bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue + many_samples = prior.rescale(self.xp.asarray(np.random.uniform(0, 1, 1000))) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_rescale(): continue self.assertTrue( all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)) ) + self.assertEqual(aac.get_namespace(many_samples), self.xp) def test_least_recently_sampled(self): for prior in self.priors: - least_recently_sampled_expected = prior.sample() + least_recently_sampled_expected = prior.sample(xp=self.xp) self.assertEqual( least_recently_sampled_expected, prior.least_recently_sampled ) + self.assertEqual(aac.get_namespace(least_recently_sampled_expected), self.xp) def test_sampling_single(self): """Test that sampling from the prior always returns values within its domain.""" @@ -305,10 +340,11 @@ def test_sampling_single(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - single_sample = prior.sample() + single_sample = prior.sample(xp=self.xp) self.assertTrue( (single_sample >= prior.minimum) & (single_sample <= prior.maximum) ) + self.assertEqual(aac.get_namespace(single_sample), self.xp) def test_sampling_many(self): """Test that sampling from the prior always returns values within its domain.""" @@ -316,17 +352,18 @@ def test_sampling_many(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - many_samples = prior.sample(5000) + many_samples = prior.sample(5000, xp=self.xp) self.assertTrue( (all(many_samples >= prior.minimum)) & (all(many_samples <= prior.maximum)) ) + self.assertEqual(aac.get_namespace(many_samples), self.xp) def test_probability_above_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in 
self.priors: if prior.maximum != np.inf: - outside_domain = np.linspace( + outside_domain = self.xp.linspace( prior.maximum + 1, prior.maximum + 1e4, 1000 ) if bilby.core.prior.JointPrior in prior.__class__.__mro__: @@ -342,7 +379,7 @@ def test_probability_below_domain(self): # SymmetricLogUniform has support down to -maximum continue if prior.minimum != -np.inf: - outside_domain = np.linspace( + outside_domain = self.xp.linspace( prior.minimum - 1e4, prior.minimum - 1, 1000 ) if bilby.core.prior.JointPrior in prior.__class__.__mro__: @@ -353,31 +390,39 @@ def test_probability_below_domain(self): def test_least_recently_sampled_2(self): for prior in self.priors: - lrs = prior.sample() + lrs = prior.sample(xp=self.xp) self.assertEqual(lrs, prior.least_recently_sampled) + self.assertEqual(aac.get_namespace(lrs), self.xp) def test_prob_and_ln_prob(self): for prior in self.priors: - sample = prior.sample() + sample = prior.sample(xp=self.xp) if not bilby.core.prior.JointPrior in prior.__class__.__mro__: # noqa # due to the way that the Multivariate Gaussian prior must sequentially call # the prob and ln_prob functions, it must be ignored in this test. 
- self.assertAlmostEqual( - np.log(prior.prob(sample)), prior.ln_prob(sample), 12 - ) + lnprob = prior.ln_prob(sample) + prob = prior.prob(sample) + self.assertEqual(aac.get_namespace(lnprob), self.xp) + self.assertEqual(aac.get_namespace(prob), self.xp) + # lower precision for jax running tests with float32 + lnprob = np.asarray(lnprob) + prob = np.asarray(prob) + self.assertAlmostEqual(np.log(prob), lnprob, 6) def test_many_prob_and_many_ln_prob(self): for prior in self.priors: - samples = prior.sample(10) + samples = prior.sample(10, xp=self.xp) if not bilby.core.prior.JointPrior in prior.__class__.__mro__: # noqa ln_probs = prior.ln_prob(samples) probs = prior.prob(samples) for sample, logp, p in zip(samples, ln_probs, probs): self.assertAlmostEqual(prior.ln_prob(sample), logp) self.assertAlmostEqual(prior.prob(sample), p) + self.assertEqual(aac.get_namespace(ln_probs), self.xp) + self.assertEqual(aac.get_namespace(probs), self.xp) def test_cdf_is_inverse_of_rescaling(self): - domain = np.linspace(0, 1, 100) + domain = self.xp.linspace(0, 1, 100) threshold = 1e-9 for prior in self.priors: if ( @@ -385,6 +430,9 @@ def test_cdf_is_inverse_of_rescaling(self): or bilby.core.prior.JointPrior in prior.__class__.__mro__ ): continue + elif isinstance(prior, bilby.core.prior.StudentT) and "jax" in str(self.xp): + # JAX implementation of StudentT prior rescale is not accurate enough + continue elif isinstance(prior, bilby.core.prior.WeightedDiscreteValues): rescaled = prior.rescale(domain) cdf_vals = prior.cdf(rescaled) @@ -392,15 +440,21 @@ def test_cdf_is_inverse_of_rescaling(self): cdf_vals_2 = prior.cdf(rescaled_2) self.assertTrue(np.array_equal(rescaled, rescaled_2)) max_difference = max(np.abs(cdf_vals - cdf_vals_2)) + for arr in [rescaled, rescaled_2, cdf_vals, cdf_vals_2]: + self.assertEqual(aac.get_namespace(arr), self.xp) else: rescaled = prior.rescale(domain) max_difference = max(np.abs(domain - prior.cdf(rescaled))) + 
self.assertEqual(aac.get_namespace(rescaled), self.xp) self.assertLess(max_difference, threshold) def test_cdf_one_above_domain(self): for prior in self.priors: + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue if prior.maximum != np.inf: - outside_domain = np.linspace( + outside_domain = self.xp.linspace( prior.maximum + 1, prior.maximum + 1e4, 1000 ) self.assertTrue(all(prior.cdf(outside_domain) == 1)) @@ -410,13 +464,16 @@ def test_cdf_zero_below_domain(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue if ( bilby.core.prior.JointPrior in prior.__class__.__mro__ and prior.maximum == np.inf ): continue if prior.minimum != -np.inf: - outside_domain = np.linspace( + outside_domain = self.xp.linspace( prior.minimum - 1e4, prior.minimum - 1, 1000 ) self.assertTrue(all(np.nan_to_num(prior.cdf(outside_domain)) == 0)) @@ -564,11 +621,18 @@ def test_probability_in_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in self.priors: if prior.minimum == -np.inf: - prior.minimum = -1e5 + minimum = -1e5 + else: + minimum = prior.minimum if prior.maximum == np.inf: - prior.maximum = 1e5 - domain = np.linspace(prior.minimum, prior.maximum, 1000) - self.assertTrue(all(prior.prob(domain) >= 0)) + maximum = 1e5 + else: + maximum = prior.maximum + domain = self.xp.linspace(minimum, maximum, 1000) + prob = prior.prob(domain) + self.assertEqual(aac.get_namespace(prob), self.xp) + prob = np.asarray(prob) + self.assertTrue(all(prob >= 0)) def test_probability_surrounding_domain(self): """Test that the prior probability is non-negative in domain of validity and zero
outside.""" @@ -579,13 +645,14 @@ def test_probability_surrounding_domain(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) - indomain = (surround_domain >= prior.minimum) | ( - surround_domain <= prior.maximum - ) - outdomain = (surround_domain < prior.minimum) | ( - surround_domain > prior.maximum - ) + with np.errstate(invalid="ignore"): + surround_domain = self.xp.linspace(prior.minimum - 1, prior.maximum + 1, 1000) + indomain = (surround_domain >= prior.minimum) | ( + surround_domain <= prior.maximum + ) + outdomain = (surround_domain < prior.minimum) | ( + surround_domain > prior.maximum + ) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_request(): continue @@ -633,11 +700,15 @@ def test_normalized(self): domain = np.linspace(prior.minimum, prior.maximum, 10000) elif isinstance(prior, bilby.core.prior.WeightedDiscreteValues): domain = prior.values - self.assertTrue(np.sum(prior.prob(domain)) == 1) + probs = prior.prob(self.xp.asarray(domain)) + self.assertEqual(aac.get_namespace(probs), self.xp) + self.assertTrue(np.sum(np.asarray(probs)) == 1) continue else: domain = np.linspace(prior.minimum, prior.maximum, 1000) - self.assertAlmostEqual(trapezoid(prior.prob(domain), domain), 1, 3) + probs = prior.prob(self.xp.asarray(domain)) + self.assertAlmostEqual(trapezoid(np.array(probs), domain), 1, 3) + self.assertEqual(aac.get_namespace(probs), self.xp) def test_accuracy(self): """Test that each of the priors' functions is calculated accurately, as compared to scipy's calculations""" @@ -732,11 +803,14 @@ def test_accuracy(self): bilby.core.prior.WeightedDiscreteValues, ) if isinstance(prior, (testTuple)): - np.testing.assert_almost_equal(prior.prob(domain), scipy_prob) - np.testing.assert_almost_equal(prior.ln_prob(domain), scipy_lnprob) - 
np.testing.assert_almost_equal(prior.cdf(domain), scipy_cdf) + np.testing.assert_almost_equal(prior.prob(self.xp.asarray(domain)), scipy_prob) + np.testing.assert_almost_equal(prior.ln_prob(self.xp.asarray(domain)), scipy_lnprob) + np.testing.assert_almost_equal(prior.cdf(self.xp.asarray(domain)), scipy_cdf) + if isinstance(prior, bilby.core.prior.StudentT) and "jax" in str(self.xp): + # JAX implementation of StudentT prior rescale is not accurate enough + continue np.testing.assert_almost_equal( - prior.rescale(rescale_domain), scipy_rescale + prior.rescale(self.xp.asarray(rescale_domain)), scipy_rescale ) def test_unit_setting(self): @@ -788,6 +862,7 @@ def test_repr(self): repr_prior_string = repr_prior_string.replace( "HealPixMapPriorDist", "bilby.gw.prior.HealPixMapPriorDist" ) + prior.dist.rescale_parameters = {key: None for key in prior.dist.names} elif isinstance(prior, bilby.gw.prior.UniformComovingVolume): repr_prior_string = "bilby.gw.prior." + repr(prior) elif "Conditional" in prior.__class__.__name__: @@ -821,7 +896,7 @@ def test_set_maximum_setting(self): ): continue prior.maximum = (prior.maximum + prior.minimum) / 2 - self.assertTrue(max(prior.sample(10000)) < prior.maximum) + self.assertTrue(max(prior.sample(10000, xp=self.xp)) < prior.maximum) def test_set_minimum_setting(self): for prior in self.priors: @@ -847,7 +922,7 @@ def test_set_minimum_setting(self): ): continue prior.minimum = (prior.maximum + prior.minimum) / 2 - self.assertTrue(min(prior.sample(10000)) > prior.minimum) + self.assertTrue(min(prior.sample(10000, xp=self.xp)) > prior.minimum) if __name__ == "__main__": diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index d2cdcc55a..7c5716b8a 100644 --- a/test/core/prior/slabspike_test.py +++ b/test/core/prior/slabspike_test.py @@ -1,6 +1,9 @@ -import numpy as np import unittest +import array_api_compat as aac +import numpy as np +import pytest + import bilby from bilby.core.prior.slabspike import 
SlabSpikePrior from bilby.core.prior.analytical import Uniform, PowerLaw, LogUniform, TruncatedGaussian, \ @@ -60,13 +63,15 @@ def test_set_spike_height_domain_edge(self): self.prior.spike_height = 1 +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestSlabSpikeClasses(unittest.TestCase): def setUp(self): - self.minimum = 0.4 - self.maximum = 2.4 - self.spike_loc = 1.5 - self.spike_height = 0.3 + self.minimum = self.xp.asarray(0.4) + self.maximum = self.xp.asarray(2.4) + self.spike_loc = self.xp.asarray(1.5) + self.spike_height = self.xp.asarray(0.3) self.slabs = [ Uniform(minimum=self.minimum, maximum=self.maximum), @@ -75,20 +80,22 @@ def setUp(self): TruncatedGaussian(minimum=self.minimum, maximum=self.maximum, mu=0, sigma=1), Beta(minimum=self.minimum, maximum=self.maximum, alpha=1, beta=1), Gaussian(mu=0, sigma=1), - Cosine(), - Sine(), + Cosine(minimum=self.xp.asarray(-np.pi / 2), maximum=self.xp.asarray(np.pi / 2)), + Sine(minimum=self.xp.asarray(0), maximum=self.xp.asarray(np.pi)), HalfGaussian(sigma=1), LogNormal(mu=1, sigma=2), Exponential(mu=2), - StudentT(df=2), Logistic(mu=2, scale=1), Cauchy(alpha=1, beta=2), Gamma(k=1, theta=1.), - ChiSquared(nu=2)] + ChiSquared(nu=2), + ] + if not aac.is_jax_namespace(self.xp): + self.slabs.append(StudentT(df=2)) self.slab_spikes = [SlabSpikePrior(slab, spike_height=self.spike_height, spike_location=self.spike_loc) for slab in self.slabs] - self.test_nodes_finite_support = np.linspace(self.minimum, self.maximum, 1000) - self.test_nodes_infinite_support = np.linspace(-10, 10, 1000) + self.test_nodes_finite_support = self.xp.linspace(self.minimum, self.maximum, 1000) + self.test_nodes_infinite_support = self.xp.linspace(-10, 10, 1000) self.test_nodes = [self.test_nodes_finite_support if np.isinf(slab.minimum) or np.isinf(slab.maximum) else self.test_nodes_finite_support for slab in self.slabs] @@ -107,6 +114,7 @@ def test_prob_on_slab(self): expected = slab.prob(test_nodes) * slab_spike.slab_fraction actual =
slab_spike.prob(test_nodes) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_prob_on_spike(self): for slab_spike in self.slab_spikes: @@ -117,10 +125,13 @@ def test_ln_prob_on_slab(self): expected = slab.ln_prob(test_nodes) + np.log(slab_spike.slab_fraction) actual = slab_spike.ln_prob(test_nodes) self.assertTrue(np.array_equal(expected, actual)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_ln_prob_on_spike(self): for slab_spike in self.slab_spikes: - self.assertEqual(np.inf, slab_spike.ln_prob(self.spike_loc)) + actual = slab_spike.ln_prob(self.spike_loc) + self.assertEqual(np.inf, actual) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_inverse_cdf_below_spike_with_spike_at_minimum(self): for slab in self.slabs: @@ -143,19 +154,22 @@ def test_cdf_below_spike(self): expected = slab.cdf(test_nodes) * slab_spike.slab_fraction actual = slab_spike.cdf(test_nodes) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_cdf_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): expected = slab.cdf(self.spike_loc) * slab_spike.slab_fraction - actual = slab_spike.cdf(self.spike_loc) + actual = slab_spike.cdf(self.xp.asarray(self.spike_loc)) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_cdf_above_spike(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): test_nodes = test_nodes[np.where(test_nodes > self.spike_loc)] expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + self.spike_height actual = slab_spike.cdf(test_nodes) - self.assertTrue(np.array_equal(expected, actual)) + np.testing.assert_allclose(expected, actual, rtol=1e-12) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_cdf_at_minimum(self): for slab_spike in self.slab_spikes: @@ -172,31 
+186,39 @@ def test_cdf_at_maximum(self): def test_rescale_no_spike(self): for slab in self.slabs: slab_spike = SlabSpikePrior(slab=slab, spike_height=0, spike_location=slab.minimum) - vals = np.linspace(0, 1, 1000) + vals = self.xp.linspace(0, 1, 1000) expected = slab.rescale(vals) actual = slab_spike.rescale(vals) - print(slab) + self.assertEqual(aac.get_namespace(actual), self.xp) + self.assertEqual(aac.get_namespace(expected), self.xp) + actual = np.asarray(actual) + expected = np.asarray(expected) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) def test_rescale_below_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - vals = np.linspace(0, slab_spike.inverse_cdf_below_spike, 1000) + vals = self.xp.linspace(0, slab_spike.inverse_cdf_below_spike, 1000) expected = slab.rescale(vals / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_rescale_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - vals = np.linspace(slab_spike.inverse_cdf_below_spike, - slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000) - expected = np.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction) + vals = self.xp.linspace( + slab_spike.inverse_cdf_below_spike, + slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000 + ) + expected = self.xp.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_rescale_above_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - vals = np.linspace(slab_spike.inverse_cdf_below_spike + self.spike_height, 1, 1000) - expected = np.ones(len(vals)) * slab.rescale( + vals = self.xp.linspace(slab_spike.inverse_cdf_below_spike + self.spike_height, 1, 1000) + 
expected = self.xp.ones(len(vals)) * slab.rescale( (vals - self.spike_height) / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(aac.get_namespace(actual), self.xp) diff --git a/test/core/result_test.py b/test/core/result_test.py index dc13a20e8..aa10b26b4 100644 --- a/test/core/result_test.py +++ b/test/core/result_test.py @@ -13,6 +13,8 @@ from bilby.core.utils import logger +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestJson(unittest.TestCase): def setUp(self): @@ -28,12 +30,12 @@ def test_list_encoding(self): self.assertTrue(np.all(data["x"] == decoded["x"])) def test_array_encoding(self): - data = dict(x=np.array([1, 2, 3.4])) + data = dict(x=self.xp.asarray([1, 2, 3.4])) encoded = json.dumps(data, cls=self.encoder) decoded = json.loads(encoded, object_hook=self.decoder) self.assertEqual(data.keys(), decoded.keys()) self.assertEqual(type(data["x"]), type(decoded["x"])) - self.assertTrue(np.all(data["x"] == decoded["x"])) + self.assertTrue(self.xp.all(data["x"] == decoded["x"])) def test_complex_encoding(self): data = dict(x=1 + 3j) @@ -918,6 +920,8 @@ def test_reweight_different_likelihood_weights_correct(self): self.assertNotEqual(new.log_evidence, self.result.log_evidence) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestResultSaveAndRead(unittest.TestCase): @pytest.fixture(autouse=True) @@ -943,7 +947,11 @@ def setUp(self): search_parameter_keys=["x", "y"], fixed_parameter_keys=["c", "d"], priors=priors, - sampler_kwargs=dict(test="test", func=lambda x: x), + sampler_kwargs=dict( + test="test", + func=lambda x: x, + some_array=self.xp.ones((5, 5)), + ), injection_parameters=dict(x=0.5, y=0.5), meta_data=dict(test="test"), sampling_time=100.0, diff --git a/test/core/series_test.py b/test/core/series_test.py index bf1b19c43..c2b8dccdb 100644 --- a/test/core/series_test.py +++ b/test/core/series_test.py @@ -1,15 +1,20 
@@ import unittest + +import array_api_compat as aac import numpy as np +import pytest from bilby.core.utils import create_frequency_series, create_time_series from bilby.core.series import CoupledTimeAndFrequencySeries +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestCoupledTimeAndFrequencySeries(unittest.TestCase): def setUp(self): - self.duration = 2 - self.sampling_frequency = 4096 - self.start_time = -1 + self.duration = self.xp.asarray(2.0) + self.sampling_frequency = self.xp.asarray(4096.0) + self.start_time = self.xp.asarray(-1.0) self.series = CoupledTimeAndFrequencySeries( duration=self.duration, sampling_frequency=self.sampling_frequency, @@ -43,10 +48,10 @@ def test_start_time_from_init(self): self.assertEqual(self.start_time, self.series.start_time) def test_frequency_array_type(self): - self.assertIsInstance(self.series.frequency_array, np.ndarray) + self.assertEqual(aac.get_namespace(self.series.frequency_array), self.xp) def test_time_array_type(self): - self.assertIsInstance(self.series.time_array, np.ndarray) + self.assertEqual(aac.get_namespace(self.series.time_array), self.xp) def test_frequency_array_from_init(self): expected = create_frequency_series( @@ -63,8 +68,8 @@ def test_time_array_from_init(self): self.assertTrue(np.array_equal(expected, self.series.time_array)) def test_frequency_array_setter(self): - new_sampling_frequency = 100 - new_duration = 3 + new_sampling_frequency = self.xp.asarray(100.0) + new_duration = self.xp.asarray(3.0) new_frequency_array = create_frequency_series( sampling_frequency=new_sampling_frequency, duration=new_duration ) @@ -79,9 +84,9 @@ def test_frequency_array_setter(self): self.assertAlmostEqual(self.start_time, self.series.start_time) def test_time_array_setter(self): - new_sampling_frequency = 100 - new_duration = 3 - new_start_time = 4 + new_sampling_frequency = self.xp.asarray(100.0) + new_duration = self.xp.asarray(3.0) + new_start_time = self.xp.asarray(4.0) new_time_array 
= create_time_series( sampling_frequency=new_sampling_frequency, duration=new_duration, @@ -90,31 +95,31 @@ def test_time_array_setter(self): self.series.time_array = new_time_array self.assertTrue(np.array_equal(new_time_array, self.series.time_array)) self.assertAlmostEqual( - new_sampling_frequency, self.series.sampling_frequency, places=1 + np.asarray(new_sampling_frequency), np.asarray(self.series.sampling_frequency), places=1 ) - self.assertAlmostEqual(new_duration, self.series.duration, places=1) - self.assertAlmostEqual(new_start_time, self.series.start_time, places=1) + self.assertAlmostEqual(np.asarray(new_duration), np.asarray(self.series.duration), places=1) + self.assertAlmostEqual(np.asarray(new_start_time), np.asarray(self.series.start_time), places=1) def test_time_array_without_sampling_frequency(self): self.series.sampling_frequency = None - self.series.duration = 4 + self.series.duration = self.xp.asarray(4) with self.assertRaises(ValueError): _ = self.series.time_array def test_time_array_without_duration(self): - self.series.sampling_frequency = 4096 + self.series.sampling_frequency = self.xp.asarray(4096) self.series.duration = None with self.assertRaises(ValueError): _ = self.series.time_array def test_frequency_array_without_sampling_frequency(self): self.series.sampling_frequency = None - self.series.duration = 4 + self.series.duration = self.xp.asarray(4) with self.assertRaises(ValueError): _ = self.series.frequency_array def test_frequency_array_without_duration(self): - self.series.sampling_frequency = 4096 + self.series.sampling_frequency = self.xp.asarray(4096) self.series.duration = None with self.assertRaises(ValueError): _ = self.series.frequency_array diff --git a/test/core/utils_test.py b/test/core/utils_test.py index df46d6bb3..d8a78beee 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -1,6 +1,8 @@ import unittest import os +import array_api_compat as aac +import array_api_extra as xpx import dill import 
numpy as np from astropy import constants @@ -49,35 +51,42 @@ def test_gravitational_constant(self): self.assertEqual(bilby.core.utils.gravitational_constant, lal.G_SI) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestFFT(unittest.TestCase): def setUp(self): - self.sampling_frequency = 10 + self.sampling_frequency = self.xp.asarray(10) def tearDown(self): del self.sampling_frequency def test_nfft_sine_function(self): - injected_frequency = 2.7324 - duration = 100 - times = utils.create_time_series(self.sampling_frequency, duration) + xp = self.xp + injected_frequency = xp.asarray(2.7324) + duration = xp.asarray(100) + times = utils.create_time_series(xp.asarray(self.sampling_frequency), duration) - time_domain_strain = np.sin(2 * np.pi * times * injected_frequency + 0.4) + time_domain_strain = xp.sin(2 * np.pi * times * injected_frequency + 0.4) frequency_domain_strain, frequencies = bilby.core.utils.nfft( time_domain_strain, self.sampling_frequency ) - frequency_at_peak = frequencies[np.argmax(np.abs(frequency_domain_strain))] + frequency_at_peak = frequencies[xp.argmax(abs(frequency_domain_strain))] + self.assertEqual(aac.get_namespace(frequency_at_peak), xp) + frequency_at_peak = np.asarray(frequency_at_peak) + injected_frequency = np.asarray(injected_frequency) self.assertAlmostEqual(injected_frequency, frequency_at_peak, places=1) def test_nfft_infft(self): - time_domain_strain = np.random.normal(0, 1, 10) + xp = self.xp + time_domain_strain = xp.asarray(np.random.normal(0, 1, 10)) frequency_domain_strain, _ = bilby.core.utils.nfft( time_domain_strain, self.sampling_frequency ) new_time_domain_strain = bilby.core.utils.infft( frequency_domain_strain, self.sampling_frequency ) - self.assertTrue(np.allclose(time_domain_strain, new_time_domain_strain)) + self.assertTrue(xp.allclose(time_domain_strain, new_time_domain_strain)) class TestInferParameters(unittest.TestCase): @@ -119,11 +128,13 @@ def 
test_self_handling_method_as_function(self): self.assertListEqual(expected, actual) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestTimeAndFrequencyArrays(unittest.TestCase): def setUp(self): - self.start_time = 1.3 - self.sampling_frequency = 5 - self.duration = 1.6 + self.start_time = self.xp.asarray(1.3) + self.sampling_frequency = self.xp.asarray(5) + self.duration = self.xp.asarray(1.6) self.frequency_array = utils.create_frequency_series( sampling_frequency=self.sampling_frequency, duration=self.duration ) @@ -141,12 +152,13 @@ def tearDown(self): del self.time_array def test_create_time_array(self): - expected_time_array = np.array([1.3, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7]) + expected_time_array = self.xp.asarray([1.3, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7]) time_array = utils.create_time_series( sampling_frequency=self.sampling_frequency, duration=self.duration, starting_time=self.start_time, ) + self.assertEqual(aac.get_namespace(time_array), self.xp) self.assertTrue(np.allclose(expected_time_array, time_array)) def test_create_frequency_array(self): @@ -164,7 +176,7 @@ def test_get_sampling_frequency_from_time_array(self): self.assertEqual(self.sampling_frequency, new_sampling_freq) def test_get_sampling_frequency_from_time_array_unequally_sampled(self): - self.time_array[-1] += 0.0001 + self.time_array = xpx.at(self.time_array, -1).set(self.time_array[-1] + 0.0001) with self.assertRaises(ValueError): _, _ = utils.get_sampling_frequency_and_duration_from_time_array( self.time_array @@ -190,7 +202,9 @@ def test_get_sampling_frequency_from_frequency_array(self): self.assertEqual(self.sampling_frequency, new_sampling_freq) def test_get_sampling_frequency_from_frequency_array_unequally_sampled(self): - self.frequency_array[-1] += 0.0001 + self.frequency_array = xpx.at( + self.frequency_array, -1 + ).set(self.frequency_array[-1] + 0.0001) with self.assertRaises(ValueError): _, _ = 
utils.get_sampling_frequency_and_duration_from_frequency_array( self.frequency_array @@ -233,34 +247,38 @@ def test_consistency_frequency_array_to_frequency_array(self): def test_illegal_sampling_frequency_and_duration(self): with self.assertRaises(utils.IllegalDurationAndSamplingFrequencyException): _ = utils.create_time_series( - sampling_frequency=7.7, duration=1.3, starting_time=0 + sampling_frequency=self.xp.asarray(7.7), + duration=self.xp.asarray(1.3), + starting_time=self.xp.asarray(0), ) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestReflect(unittest.TestCase): def test_in_range(self): - xprime = np.array([0.1, 0.5, 0.9]) - x = np.array([0.1, 0.5, 0.9]) + xprime = self.xp.asarray([0.1, 0.5, 0.9]) + x = self.xp.asarray([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_one_to_two(self): - xprime = np.array([1.1, 1.5, 1.9]) - x = np.array([0.9, 0.5, 0.1]) + xprime = self.xp.asarray([1.1, 1.5, 1.9]) + x = self.xp.asarray([0.9, 0.5, 0.1]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_two_to_three(self): - xprime = np.array([2.1, 2.5, 2.9]) - x = np.array([0.1, 0.5, 0.9]) + xprime = self.xp.asarray([2.1, 2.5, 2.9]) + x = self.xp.asarray([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_minus_one_to_zero(self): - xprime = np.array([-0.9, -0.5, -0.1]) - x = np.array([0.9, 0.5, 0.1]) + xprime = self.xp.asarray([-0.9, -0.5, -0.1]) + x = self.xp.asarray([0.9, 0.5, 0.1]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_minus_two_to_minus_one(self): - xprime = np.array([-1.9, -1.5, -1.1]) - x = np.array([0.1, 0.5, 0.9]) + xprime = self.xp.asarray([-1.9, -1.5, -1.1]) + x = self.xp.asarray([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) @@ -325,8 +343,12 @@ def plot(): 
self.assertTrue(os.path.isfile(self.filename)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestUnsortedInterp2d(unittest.TestCase): def setUp(self): + if aac.is_torch_namespace(self.xp): + pytest.skip("Skipping Interp2d tests for torch backend") self.xx = np.linspace(0, 1, 10) self.yy = np.linspace(0, 1, 10) self.zz = np.random.random((10, 10)) @@ -343,36 +365,42 @@ def test_returns_none_for_floats_outside_range(self): self.assertIsNone(self.interpolant(-0.5, 0.5)) def test_returns_float_for_float_and_array(self): - self.assertIsInstance(self.interpolant(0.5, np.random.random(10)), np.ndarray) - self.assertIsInstance(self.interpolant(np.random.random(10), 0.5), np.ndarray) - self.assertIsInstance( - self.interpolant(np.random.random(10), np.random.random(10)), np.ndarray + input_array = self.xp.asarray(np.random.random(10)) + self.assertEqual(aac.get_namespace(self.interpolant(input_array, 0.5)), self.xp) + self.assertEqual(aac.get_namespace( + self.interpolant(input_array, input_array)), self.xp ) + self.assertEqual(aac.get_namespace(self.interpolant(0.5, input_array)), self.xp) def test_raises_for_mismatched_arrays(self): with self.assertRaises(ValueError): - self.interpolant(np.random.random(10), np.random.random(20)) + self.interpolant( + self.xp.asarray(np.random.random(10)), + self.xp.asarray(np.random.random(20)), + ) def test_returns_fill_in_correct_place(self): - x_data = np.random.random(10) - y_data = np.random.random(10) - x_data[3] = -1 - self.assertTrue(np.isnan(self.interpolant(x_data, y_data)[3])) + x_data = self.xp.asarray(np.random.random(10)) + y_data = self.xp.asarray(np.random.random(10)) + x_data = xpx.at(x_data, 3).set(-1) + self.assertTrue(self.xp.isnan(self.interpolant(x_data, y_data)[3])) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestTrapeziumRuleIntegration(unittest.TestCase): def setUp(self): - self.x = np.linspace(0, 1, 100) - self.dxs = np.diff(self.x) + self.x = 
self.xp.linspace(0, 1, 100) + self.dxs = self.xp.diff(self.x) self.dx = self.dxs[0] with np.errstate(divide="ignore"): - self.lnfunc1 = np.log(self.x) + self.lnfunc1 = self.xp.log(self.x) self.func1int = (self.x[-1] ** 2 - self.x[0] ** 2) / 2 with np.errstate(divide="ignore"): - self.lnfunc2 = np.log(self.x ** 2) + self.lnfunc2 = self.xp.log(self.x ** 2) self.func2int = (self.x[-1] ** 3 - self.x[0] ** 3) / 3 - self.irregularx = np.array( + self.irregularx = self.xp.asarray( [ self.x[0], self.x[12], @@ -390,9 +418,9 @@ def setUp(self): ] ) with np.errstate(divide="ignore"): - self.lnfunc1irregular = np.log(self.irregularx) - self.lnfunc2irregular = np.log(self.irregularx ** 2) - self.irregulardxs = np.diff(self.irregularx) + self.lnfunc1irregular = self.xp.log(self.irregularx) + self.lnfunc2irregular = self.xp.log(self.irregularx ** 2) + self.irregulardxs = self.xp.diff(self.irregularx) def test_incorrect_step_type(self): with self.assertRaises(TypeError): @@ -407,19 +435,19 @@ def test_integral_func1(self): res2 = utils.logtrapzexp(self.lnfunc1, self.dxs) self.assertTrue(np.abs(res1 - res2) < 1e-12) - self.assertTrue(np.abs((np.exp(res1) - self.func1int) / self.func1int) < 1e-12) + self.assertTrue(np.abs((self.xp.exp(res1) - self.func1int) / self.func1int) < 1e-12) def test_integral_func2(self): res = utils.logtrapzexp(self.lnfunc2, self.dxs) - self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-4) + self.assertTrue(np.abs((self.xp.exp(res) - self.func2int) / self.func2int) < 1e-4) def test_integral_func1_irregular_steps(self): res = utils.logtrapzexp(self.lnfunc1irregular, self.irregulardxs) - self.assertTrue(np.abs((np.exp(res) - self.func1int) / self.func1int) < 1e-12) + self.assertTrue(np.abs((self.xp.exp(res) - self.func1int) / self.func1int) < 1e-12) def test_integral_func2_irregular_steps(self): res = utils.logtrapzexp(self.lnfunc2irregular, self.irregulardxs) - self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 
1e-2) + self.assertTrue(np.abs((self.xp.exp(res) - self.func2int) / self.func2int) < 1e-2) class TestSavingNumpyRandomGenerator(unittest.TestCase): diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index fc0f4321a..9d2d46b48 100644 --- a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -1,25 +1,29 @@ import unittest +import array_api_compat as aac import numpy as np import pandas as pd +import pytest import bilby from bilby.gw import conversion +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestBasicConversions(unittest.TestCase): def setUp(self): - self.mass_1 = 1.4 - self.mass_2 = 1.3 - self.mass_ratio = 13 / 14 - self.total_mass = 2.7 - self.chirp_mass = (1.4 * 1.3) ** 0.6 / 2.7 ** 0.2 - self.symmetric_mass_ratio = (1.4 * 1.3) / 2.7 ** 2 - self.cos_angle = -1 - self.angle = np.pi - self.lambda_1 = 300 - self.lambda_2 = 300 * (14 / 13) ** 5 - self.lambda_tilde = ( + self.mass_1 = self.xp.asarray(1.4) + self.mass_2 = self.xp.asarray(1.3) + self.mass_ratio = self.xp.asarray(13 / 14) + self.total_mass = self.xp.asarray(2.7) + self.chirp_mass = (self.mass_1 * self.mass_2) ** 0.6 / self.total_mass ** 0.2 + self.symmetric_mass_ratio = (self.mass_1 * self.mass_2) / self.total_mass ** 2 + self.cos_angle = self.xp.asarray(-1.0) + self.angle = self.xp.pi + self.lambda_1 = self.xp.asarray(300.0) + self.lambda_2 = self.xp.asarray(300.0 * (14 / 13) ** 5) + self.lambda_tilde = self.xp.asarray( 8 / 13 * ( @@ -38,7 +42,7 @@ def setUp(self): * (self.lambda_1 - self.lambda_2) ) ) - self.delta_lambda_tilde = ( + self.delta_lambda_tilde = self.xp.asarray( 1 / 2 * ( @@ -74,30 +78,36 @@ def test_total_mass_and_mass_ratio_to_component_masses(self): self.assertTrue( all([abs(mass_1 - self.mass_1) < 1e-5, abs(mass_2 - self.mass_2) < 1e-5]) ) + self.assertEqual(aac.get_namespace(mass_1), self.xp) + self.assertEqual(aac.get_namespace(mass_2), self.xp) def test_chirp_mass_and_primary_mass_to_mass_ratio(self): mass_ratio = 
conversion.chirp_mass_and_primary_mass_to_mass_ratio( self.chirp_mass, self.mass_1 ) - self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertAlmostEqual(float(self.mass_ratio), float(mass_ratio)) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_symmetric_mass_ratio_to_mass_ratio(self): mass_ratio = conversion.symmetric_mass_ratio_to_mass_ratio( self.symmetric_mass_ratio ) - self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertAlmostEqual(float(self.mass_ratio), float(mass_ratio)) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_chirp_mass_and_total_mass_to_symmetric_mass_ratio(self): symmetric_mass_ratio = conversion.chirp_mass_and_total_mass_to_symmetric_mass_ratio( self.chirp_mass, self.total_mass ) - self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) + self.assertAlmostEqual(float(self.symmetric_mass_ratio), float(symmetric_mass_ratio)) + self.assertEqual(aac.get_namespace(symmetric_mass_ratio), self.xp) def test_chirp_mass_and_mass_ratio_to_total_mass(self): total_mass = conversion.chirp_mass_and_mass_ratio_to_total_mass( self.chirp_mass, self.mass_ratio ) - self.assertAlmostEqual(self.total_mass, total_mass) + self.assertAlmostEqual(float(self.total_mass), float(total_mass)) + self.assertEqual(aac.get_namespace(total_mass), self.xp) def test_chirp_mass_and_mass_ratio_to_component_masses(self): mass_1, mass_2 = \ @@ -105,30 +115,37 @@ def test_chirp_mass_and_mass_ratio_to_component_masses(self): self.chirp_mass, self.mass_ratio) self.assertAlmostEqual(self.mass_1, mass_1) self.assertAlmostEqual(self.mass_2, mass_2) + self.assertEqual(aac.get_namespace(mass_1), self.xp) + self.assertEqual(aac.get_namespace(mass_2), self.xp) def test_component_masses_to_chirp_mass(self): chirp_mass = conversion.component_masses_to_chirp_mass(self.mass_1, self.mass_2) self.assertAlmostEqual(self.chirp_mass, chirp_mass) + self.assertEqual(aac.get_namespace(chirp_mass), self.xp) def 
test_component_masses_to_total_mass(self): total_mass = conversion.component_masses_to_total_mass(self.mass_1, self.mass_2) self.assertAlmostEqual(self.total_mass, total_mass) + self.assertEqual(aac.get_namespace(total_mass), self.xp) def test_component_masses_to_symmetric_mass_ratio(self): symmetric_mass_ratio = conversion.component_masses_to_symmetric_mass_ratio( self.mass_1, self.mass_2 ) self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) + self.assertEqual(aac.get_namespace(symmetric_mass_ratio), self.xp) def test_component_masses_to_mass_ratio(self): mass_ratio = conversion.component_masses_to_mass_ratio(self.mass_1, self.mass_2) - self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertAlmostEqual(float(self.mass_ratio), float(mass_ratio)) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_mass_1_and_chirp_mass_to_mass_ratio(self): mass_ratio = conversion.mass_1_and_chirp_mass_to_mass_ratio( self.mass_1, self.chirp_mass ) self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_lambda_tilde_to_lambda_1_lambda_2(self): lambda_1, lambda_2 = conversion.lambda_tilde_to_lambda_1_lambda_2( @@ -142,6 +159,8 @@ def test_lambda_tilde_to_lambda_1_lambda_2(self): ] ) ) + self.assertEqual(aac.get_namespace(lambda_1), self.xp) + self.assertEqual(aac.get_namespace(lambda_2), self.xp) def test_lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2(self): ( @@ -158,18 +177,22 @@ def test_lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2(self): ] ) ) + self.assertEqual(aac.get_namespace(lambda_1), self.xp) + self.assertEqual(aac.get_namespace(lambda_2), self.xp) def test_lambda_1_lambda_2_to_lambda_tilde(self): lambda_tilde = conversion.lambda_1_lambda_2_to_lambda_tilde( self.lambda_1, self.lambda_2, self.mass_1, self.mass_2 ) self.assertTrue((self.lambda_tilde - lambda_tilde) < 1e-5) + self.assertEqual(aac.get_namespace(lambda_tilde), self.xp) def 
test_lambda_1_lambda_2_to_delta_lambda_tilde(self): delta_lambda_tilde = conversion.lambda_1_lambda_2_to_delta_lambda_tilde( self.lambda_1, self.lambda_2, self.mass_1, self.mass_2 ) self.assertTrue((self.delta_lambda_tilde - delta_lambda_tilde) < 1e-5) + self.assertEqual(aac.get_namespace(delta_lambda_tilde), self.xp) def test_identity_conversion(self): original_samples = dict( @@ -600,18 +623,20 @@ def test_comoving_luminosity_with_cosmology(self): self.assertAlmostEqual(max(abs(dl - self.distances)), 0, 4) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGenerateMassParameters(unittest.TestCase): def setUp(self): - self.expected_values = {'mass_1': 2.0, - 'mass_2': 1.0, - 'chirp_mass': 1.2167286837864113, - 'total_mass': 3.0, - 'mass_1_source': 4.0, - 'mass_2_source': 2.0, - 'chirp_mass_source': 2.433457367572823, - 'total_mass_source': 6, - 'symmetric_mass_ratio': 0.2222222222222222, - 'mass_ratio': 0.5} + self.expected_values = {'mass_1': self.xp.asarray(2.0), + 'mass_2': self.xp.asarray(1.0), + 'chirp_mass': self.xp.asarray(1.2167286837864113), + 'total_mass': self.xp.asarray(3.0), + 'mass_1_source': self.xp.asarray(4.0), + 'mass_2_source': self.xp.asarray(2.0), + 'chirp_mass_source': self.xp.asarray(2.433457367572823), + 'total_mass_source': self.xp.asarray(6), + 'symmetric_mass_ratio': self.xp.asarray(0.2222222222222222), + 'mass_ratio': self.xp.asarray(0.5)} def helper_generation_from_keys(self, keys, expected_values, source=False): # Explicitly test the helper generate_component_masses @@ -627,8 +652,8 @@ def helper_generation_from_keys(self, keys, expected_values, source=False): self.assertTrue("mass_2" in local_test_vars_with_component_masses.keys()) for key in local_test_vars_with_component_masses.keys(): self.assertAlmostEqual( - local_test_vars_with_component_masses[key], - self.expected_values[key]) + np.asarray(local_test_vars_with_component_masses[key]), + np.asarray(self.expected_values[key])) # Test the function more 
generally local_all_mass_parameters = \ @@ -658,7 +683,14 @@ def helper_generation_from_keys(self, keys, expected_values, source=False): ) ) for key in local_all_mass_parameters.keys(): - self.assertAlmostEqual(expected_values[key], local_all_mass_parameters[key]) + self.assertAlmostEqual( + np.asarray(expected_values[key]), + np.asarray(local_all_mass_parameters[key]), + ) + self.assertEqual( + aac.get_namespace(local_all_mass_parameters[key]), + self.xp, + ) def test_from_mass_1_and_mass_2(self): self.helper_generation_from_keys(["mass_1", "mass_2"], @@ -725,6 +757,8 @@ def test_from_chirp_mass_source_and_symmetric_mass_2(self): self.expected_values, source=True) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestEquationOfStateConversions(unittest.TestCase): ''' Class to test equation of state conversions. @@ -733,48 +767,48 @@ class TestEquationOfStateConversions(unittest.TestCase): ''' def setUp(self): - self.mass_1_source_spectral = [ + self.mass_1_source_spectral = self.xp.asarray([ 4.922542724434885, 4.350626907771598, 4.206155335439082, 1.7822696459661311, 1.3091740103047926 - ] - self.mass_2_source_spectral = [ + ]) + self.mass_2_source_spectral = self.xp.asarray([ 3.459974694590303, 1.2276461777181447, 3.7287707089639976, 0.3724016563531846, 1.055042934805801 - ] - self.spectral_pca_gamma_0 = [ + ]) + self.spectral_pca_gamma_0 = self.xp.asarray([ 0.7074873121348357, 0.05855931126849878, 0.7795329261793462, 1.467907561566463, 2.9066488405635624 - ] - self.spectral_pca_gamma_1 = [ + ]) + self.spectral_pca_gamma_1 = self.xp.asarray([ -0.29807111670823816, 2.027708558522935, -1.4415775226512115, -0.7104870098896858, -0.4913817181089619 - ] - self.spectral_pca_gamma_2 = [ + ]) + self.spectral_pca_gamma_2 = self.xp.asarray([ 0.25625095371021156, -0.19574096643220049, -0.2710238103460012, 0.22815820981582358, -0.1543413205016374 - ] - self.spectral_pca_gamma_3 = [ + ]) + self.spectral_pca_gamma_3 = self.xp.asarray([ 
-0.04030365100175101, 0.05698030777919032, -0.045595911403040264, -0.023480394227900117, -0.07114492992285618 - ] + ]) self.spectral_gamma_0 = [ 1.1259406796075457, 0.3191335618787259, @@ -875,10 +909,12 @@ def test_spectral_pca_to_spectral(self): self.spectral_pca_gamma_2[i], self.spectral_pca_gamma_3[i] ) - self.assertAlmostEqual(spectral_gamma_0, self.spectral_gamma_0[i], places=5) - self.assertAlmostEqual(spectral_gamma_1, self.spectral_gamma_1[i], places=5) - self.assertAlmostEqual(spectral_gamma_2, self.spectral_gamma_2[i], places=5) - self.assertAlmostEqual(spectral_gamma_3, self.spectral_gamma_3[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_0), self.spectral_gamma_0[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_1), self.spectral_gamma_1[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_2), self.spectral_gamma_2[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_3), self.spectral_gamma_3[i], places=5) + for val in [spectral_gamma_0, spectral_gamma_1, spectral_gamma_2, spectral_gamma_3]: + self.assertEqual(aac.get_namespace(val), self.xp) def test_spectral_params_to_lambda_1_lambda_2(self): ''' @@ -906,8 +942,8 @@ def test_spectral_params_to_lambda_1_lambda_2(self): self.mass_1_source_spectral[i], self.mass_2_source_spectral[i] ) - self.assertAlmostEqual(self.lambda_1_spectral[i], lambda_1, places=0) - self.assertAlmostEqual(self.lambda_2_spectral[i], lambda_2, places=0) + self.assertAlmostEqual(self.lambda_1_spectral[i], float(lambda_1), places=0) + self.assertAlmostEqual(self.lambda_2_spectral[i], float(lambda_2), places=0) self.assertAlmostEqual(self.eos_check_spectral[i], eos_check) def test_polytrope_or_causal_params_to_lambda_1_lambda_2_causal(self): diff --git a/test/gw/detector/geometry_test.py b/test/gw/detector/geometry_test.py index 358825b23..7340a5f8d 100644 --- a/test/gw/detector/geometry_test.py +++ b/test/gw/detector/geometry_test.py @@ -1,11 +1,15 @@ import unittest from unittest import mock 
+import array_api_compat as aac import numpy as np +import pytest import bilby +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestInterferometerGeometry(unittest.TestCase): def setUp(self): self.length = 30 @@ -26,6 +30,7 @@ def setUp(self): xarm_tilt=self.xarm_tilt, yarm_tilt=self.yarm_tilt, ) + self.geometry.set_array_backend(self.xp) def tearDown(self): del self.length @@ -40,27 +45,35 @@ def tearDown(self): def test_length_setting(self): self.assertEqual(self.geometry.length, self.length) + self.assertEqual(aac.get_namespace(self.geometry.length), self.xp) def test_latitude_setting(self): self.assertEqual(self.geometry.latitude, self.latitude) + self.assertEqual(aac.get_namespace(self.geometry.latitude), self.xp) def test_longitude_setting(self): self.assertEqual(self.geometry.longitude, self.longitude) + self.assertEqual(aac.get_namespace(self.geometry.longitude), self.xp) def test_elevation_setting(self): self.assertEqual(self.geometry.elevation, self.elevation) + self.assertEqual(aac.get_namespace(self.geometry.elevation), self.xp) def test_xarm_azi_setting(self): self.assertEqual(self.geometry.xarm_azimuth, self.xarm_azimuth) + self.assertEqual(aac.get_namespace(self.geometry.xarm_azimuth), self.xp) def test_yarm_azi_setting(self): self.assertEqual(self.geometry.yarm_azimuth, self.yarm_azimuth) + self.assertEqual(aac.get_namespace(self.geometry.yarm_azimuth), self.xp) def test_xarm_tilt_setting(self): self.assertEqual(self.geometry.xarm_tilt, self.xarm_tilt) + self.assertEqual(aac.get_namespace(self.geometry.xarm_tilt), self.xp) def test_yarm_tilt_setting(self): self.assertEqual(self.geometry.yarm_tilt, self.yarm_tilt) + self.assertEqual(aac.get_namespace(self.geometry.yarm_tilt), self.xp) def test_vertex_without_update(self): _ = self.geometry.vertex @@ -141,32 +154,38 @@ def test_y_with_latitude_update(self): def test_detector_tensor_with_x_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.xarm_azimuth 
+= 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_y_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.yarm_azimuth += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_x_tilt_update(self): original = self.geometry.detector_tensor self.geometry.xarm_tilt += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_y_tilt_update(self): original = self.geometry.detector_tensor self.geometry.yarm_tilt += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_longitude_update(self): original = self.geometry.detector_tensor self.geometry.longitude += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_latitude_update(self): original = self.geometry.detector_tensor self.geometry.latitude += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + 
self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_unit_vector_along_arm_default(self): with self.assertRaises(ValueError): @@ -177,17 +196,20 @@ def test_unit_vector_along_arm_x(self): self.geometry.latitude = 0 self.geometry.xarm_tilt = 0 self.geometry.xarm_azimuth = 0 + self.geometry.set_array_backend(self.xp) arm = self.geometry.unit_vector_along_arm("x") self.assertTrue(np.allclose(arm, np.array([0, 1, 0]))) + self.assertEqual(aac.get_namespace(arm), self.xp) def test_unit_vector_along_arm_y(self): self.geometry.longitude = 0 self.geometry.latitude = 0 self.geometry.yarm_tilt = 0 self.geometry.yarm_azimuth = 90 + self.geometry.set_array_backend(self.xp) arm = self.geometry.unit_vector_along_arm("y") - print(arm) self.assertTrue(np.allclose(arm, np.array([0, 0, 1]))) + self.assertEqual(aac.get_namespace(arm), self.xp) def test_repr(self): expected = ( diff --git a/test/gw/likelihood/marginalization_test.py b/test/gw/likelihood/marginalization_test.py index 351e516f8..b5da6ba16 100644 --- a/test/gw/likelihood/marginalization_test.py +++ b/test/gw/likelihood/marginalization_test.py @@ -3,6 +3,7 @@ import pytest import unittest from copy import deepcopy +from functools import cached_property from itertools import product from parameterized import parameterized @@ -230,54 +231,63 @@ def setUp(self): maximum=self.parameters["geocent_time"] + 0.1 ) - trial_roq_paths = [ - "/roq_basis", - os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), - "/home/cbc/ROQ_data/IMRPhenomPv2/4s", - ] - roq_dir = None - for path in trial_roq_paths: - if os.path.isdir(path): - roq_dir = path - break - if roq_dir is None: - raise Exception("Unable to load ROQ basis: cannot proceed with tests") - - self.roq_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + self.relbin_waveform_generator = 
bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning, start_time=1126259640, waveform_arguments=dict( reference_frequency=20.0, + minimum_frequency=20.0, waveform_approximant="IMRPhenomPv2", - frequency_nodes_linear=np.load(f"{roq_dir}/fnodes_linear.npy"), - frequency_nodes_quadratic=np.load(f"{roq_dir}/fnodes_quadratic.npy"), ) ) - self.roq_linear_matrix_file = f"{roq_dir}/B_linear.npy" - self.roq_quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy" - self.relbin_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + self.multiband_waveform_generator = bilby.gw.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, start_time=1126259640, waveform_arguments=dict( reference_frequency=20.0, - minimum_frequency=20.0, waveform_approximant="IMRPhenomPv2", ) ) - self.multiband_waveform_generator = bilby.gw.WaveformGenerator( + @property + def roq_dir(self): + trial_roq_paths = [ + "/roq_basis", + os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), + "/home/cbc/ROQ_data/IMRPhenomPv2/4s", + ] + if "BILBY_TESTING_ROQ_DIR" in os.environ: + trial_roq_paths.insert(0, os.environ["BILBY_TESTING_ROQ_DIR"]) + for path in trial_roq_paths: + if os.path.isdir(path): + return path + raise Exception("Unable to load ROQ basis: cannot proceed with tests") + + @property + def roq_linear_matrix_file(self): + return f"{self.roq_dir}/B_linear.npy" + + @property + def roq_quadratic_matrix_file(self): + return f"{self.roq_dir}/B_quadratic.npy" + + @cached_property + def roq_waveform_generator(self): + return 
bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, start_time=1126259640, waveform_arguments=dict( reference_frequency=20.0, waveform_approximant="IMRPhenomPv2", + frequency_nodes_linear=np.load(f"{self.roq_dir}/fnodes_linear.npy"), + frequency_nodes_quadratic=np.load(f"{self.roq_dir}/fnodes_quadratic.npy"), ) ) @@ -287,7 +297,6 @@ def tearDown(self): del self.parameters del self.interferometers del self.waveform_generator - del self.roq_waveform_generator del self.priors @classmethod diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 9d7a7e36f..683d07a3e 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -1,17 +1,61 @@ import os import unittest import tempfile +from functools import cached_property from itertools import product from parameterized import parameterized import pytest from copy import deepcopy +import array_api_compat as aac import h5py import numpy as np import bilby +from array_api_compat import is_array_api_obj from bilby.gw.likelihood import BilbyROQParamsRangeError +class BackendWaveformGenerator(bilby.gw.waveform_generator.WaveformGenerator): + """ + A thin wrapper to emulate different backends in the waveform generator. + + This ensures that all frequency arrays that might be used inside the + source are cast to numpy for compatibility. The outputs are converted + to the appropriate array type. 
+ """ + def __init__(self, wfg, xp): + self.wfg = wfg + self.xp = xp + + def __getattr__(self, name): + if name == "xp": + return self.xp + return getattr(self.wfg, name) + + def convert_nested_dict(self, data): + if is_array_api_obj(data): + return self.xp.asarray(data) + elif isinstance(data, dict): + return {key: self.convert_nested_dict(value) for key, value in data.items()} + else: + raise ValueError("Input must be an array API object or a dict of such objects.") + + def _strain_from_model(self, model_data_points, model, parameters): + model_data_points = np.asarray(model_data_points) + return super()._strain_from_model(model_data_points, model, parameters) + + def frequency_domain_strain(self, parameters): + self.wfg.frequency_array = np.asarray(self.wfg.frequency_array) + if "frequency_nodes" in self.wfg.waveform_arguments: + self.wfg.waveform_arguments["frequency_nodes"] = np.asarray( + self.wfg.waveform_arguments["frequency_nodes"] + ) + wf = self.wfg.__class__.frequency_domain_strain(self, parameters) + return self.convert_nested_dict(wf) + + +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestBasicGWTransient(unittest.TestCase): def setUp(self): bilby.core.utils.random.seed(500) @@ -26,21 +70,23 @@ def setUp(self): phi_jl=0.3, luminosity_distance=4000.0, theta_jn=0.4, - psi=2.659, + psi=self.xp.asarray(2.659), phase=1.3, - geocent_time=1126259642.413, - ra=1.375, - dec=-1.2108, + geocent_time=self.xp.asarray(1126259642.413), + ra=self.xp.asarray(1.375), + dec=self.xp.asarray(-1.2108), ) self.interferometers = bilby.gw.detector.InterferometerList(["H1"]) self.interferometers.set_strain_data_from_power_spectral_densities( - sampling_frequency=2048, duration=4 + sampling_frequency=self.xp.asarray(2048.0), duration=self.xp.asarray(4.0) ) - self.waveform_generator = bilby.gw.waveform_generator.GWSignalWaveformGenerator( - duration=4, - sampling_frequency=2048, + self.interferometers.set_array_backend(self.xp) + base_wfg = 
bilby.gw.waveform_generator.GWSignalWaveformGenerator( + duration=self.xp.asarray(4.0), + sampling_frequency=self.xp.asarray(2048.0), waveform_arguments=dict(waveform_approximant="IMRPhenomPv2"), ) + self.waveform_generator = BackendWaveformGenerator(base_wfg, self.xp) self.likelihood = bilby.gw.likelihood.BasicGravitationalWaveTransient( interferometers=self.interferometers, @@ -55,23 +101,27 @@ def tearDown(self): def test_noise_log_likelihood(self): """Test noise log likelihood matches precomputed value""" - self.likelihood.noise_log_likelihood() + nll = self.likelihood.noise_log_likelihood() self.assertAlmostEqual( - -4014.1787704539474, self.likelihood.noise_log_likelihood(), 3 + -4014.1787704539474, float(nll), 3 ) + self.assertEqual(aac.get_namespace(nll), self.xp) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" - self.likelihood.log_likelihood(self.parameters) - self.assertAlmostEqual(self.likelihood.log_likelihood(self.parameters), -4032.4397343470005, 3) + logl = self.likelihood.log_likelihood(self.parameters) + self.assertAlmostEqual(float(logl), -4032.4397343470005, 3) + self.assertEqual(aac.get_namespace(logl), self.xp) def test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" + llr = self.likelihood.log_likelihood_ratio(self.parameters) self.assertAlmostEqual( self.likelihood.log_likelihood(self.parameters) - self.likelihood.noise_log_likelihood(), - self.likelihood.log_likelihood_ratio(self.parameters), + llr, 3, ) + self.assertEqual(aac.get_namespace(llr), self.xp) def test_likelihood_zero_when_waveform_is_none(self): """Test log likelihood returns np.nan_to_num(-np.inf) when the @@ -86,11 +136,13 @@ def test_repr(self): self.assertEqual(expected, repr(self.likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGWTransient(unittest.TestCase): def setUp(self): bilby.core.utils.random.seed(500) - self.duration = 4 - self.sampling_frequency = 2048 
+ self.duration = self.xp.asarray(4.0) + self.sampling_frequency = self.xp.asarray(2048.0) self.parameters = dict( mass_1=31.0, mass_2=29.0, @@ -102,21 +154,23 @@ def setUp(self): phi_jl=0.3, luminosity_distance=4000.0, theta_jn=0.4, - psi=2.659, - phase=1.3, - geocent_time=1126259642.413, - ra=1.375, - dec=-1.2108, + psi=self.xp.asarray(2.659), + phase=self.xp.asarray(1.3), + geocent_time=self.xp.asarray(1126259642.413), + ra=self.xp.asarray(1.375), + dec=self.xp.asarray(-1.2108), ) self.interferometers = bilby.gw.detector.InterferometerList(["H1"]) self.interferometers.set_strain_data_from_power_spectral_densities( sampling_frequency=self.sampling_frequency, duration=self.duration ) - self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + self.interferometers.set_array_backend(self.xp) + wfg = bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, ) + self.waveform_generator = BackendWaveformGenerator(wfg, self.xp) self.prior = bilby.gw.prior.BBHPriorDict() self.prior["geocent_time"] = bilby.prior.Uniform( @@ -139,24 +193,27 @@ def tearDown(self): def test_noise_log_likelihood(self): """Test noise log likelihood matches precomputed value""" - self.likelihood.noise_log_likelihood() + nll = self.likelihood.noise_log_likelihood() + self.assertEqual(aac.get_namespace(nll), self.xp) self.assertAlmostEqual( - -4014.1787704539474, self.likelihood.noise_log_likelihood(), 3 + -4014.1787704539474, float(nll), 3 ) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" - self.likelihood.log_likelihood(self.parameters) - self.assertAlmostEqual(self.likelihood.log_likelihood(self.parameters), - -4032.4397343470005, 3) + logl = self.likelihood.log_likelihood(self.parameters) + self.assertAlmostEqual(float(logl), -4032.4397343470005, 3) + self.assertEqual(aac.get_namespace(logl), self.xp) def 
test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" + llr = self.likelihood.log_likelihood_ratio(self.parameters) self.assertAlmostEqual( - self.likelihood.log_likelihood(self.parameters) - self.likelihood.noise_log_likelihood(), - self.likelihood.log_likelihood_ratio(self.parameters), + float(self.likelihood.log_likelihood(self.parameters)) - float(self.likelihood.noise_log_likelihood()), + float(llr), 3, ) + self.assertEqual(aac.get_namespace(llr), self.xp) def test_likelihood_zero_when_waveform_is_none(self): """Test log likelihood returns np.nan_to_num(-np.inf) when the @@ -236,14 +293,16 @@ def test_reference_frame_agrees_with_default(self): ) parameters = self.parameters.copy() del parameters["ra"], parameters["dec"] - parameters["zenith"] = 1.0 - parameters["azimuth"] = 1.0 + parameters["zenith"] = self.xp.asarray(1.0) + parameters["azimuth"] = self.xp.asarray(1.0) parameters["ra"], parameters["dec"] = bilby.gw.utils.zenith_azimuth_to_ra_dec( zenith=parameters["zenith"], azimuth=parameters["azimuth"], geocent_time=parameters["geocent_time"], - ifos=bilby.gw.detector.InterferometerList(["H1", "L1"]) + ifos=new_likelihood.reference_frame, ) + self.assertEqual(aac.get_namespace(parameters["ra"]), self.xp) + self.assertEqual(aac.get_namespace(parameters["dec"]), self.xp) self.assertEqual( new_likelihood.log_likelihood_ratio(parameters), self.likelihood.log_likelihood_ratio(parameters) @@ -264,42 +323,39 @@ def test_time_reference_agrees_with_default(self): ) parameters = self.parameters.copy() parameters["H1_time"] = parameters["geocent_time"] + time_delay - self.assertEqual( + self.assertAlmostEqual( new_likelihood.log_likelihood_ratio(parameters), - self.likelihood.log_likelihood_ratio(parameters) + self.likelihood.log_likelihood_ratio(parameters), + 8, ) -@pytest.mark.requires_roqs -class TestROQLikelihood(unittest.TestCase): - def setUp(self): - self.duration = 4 - self.sampling_frequency = 2048 +class ROQBasisMixin: - 
# Possible locations for the ROQ: in the docker image, local, or on CIT + @property + def roq_dir(self): trial_roq_paths = [ "/roq_basis", os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), "/home/cbc/ROQ_data/IMRPhenomPv2/4s", ] - roq_dir = None + if "BILBY_TESTING_ROQ_DIR" in os.environ: + trial_roq_paths.insert(0, os.environ["BILBY_TESTING_ROQ_DIR"]) for path in trial_roq_paths: if os.path.isdir(path): - roq_dir = path - break - if roq_dir is None: - raise Exception("Unable to load ROQ basis: cannot proceed with tests") + return path + raise Exception("Unable to load ROQ basis: cannot proceed with tests") - linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - fnodes_linear_file = "{}/fnodes_linear.npy".format(roq_dir) - fnodes_linear = np.load(fnodes_linear_file).T - fnodes_quadratic_file = "{}/fnodes_quadratic.npy".format(roq_dir) - fnodes_quadratic = np.load(fnodes_quadratic_file).T - self.linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - self.quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - self.params_file = "{}/params.dat".format(roq_dir) +@pytest.mark.requires_roqs +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") +@pytest.mark.flaky(reruns=3) # pyfftw is flake on some machines +class TestROQLikelihood(ROQBasisMixin, unittest.TestCase): + def setUp(self): + self.duration = self.xp.asarray(4.0) + self.sampling_frequency = self.xp.asarray(2048.0) + bilby.core.utils.random.seed(500) self.test_parameters = dict( mass_1=36.0, @@ -312,17 +368,18 @@ def setUp(self): phi_jl=0.3, luminosity_distance=1000.0, theta_jn=0.4, - psi=0.659, + psi=self.xp.asarray(0.659), phase=1.3, - geocent_time=1.2, - ra=1.3, - dec=-1.2, + geocent_time=self.xp.asarray(1.2), + ra=self.xp.asarray(1.3), + dec=self.xp.asarray(-1.2), ) ifos = bilby.gw.detector.InterferometerList(["H1"]) ifos.set_strain_data_from_power_spectral_densities( 
sampling_frequency=self.sampling_frequency, duration=self.duration ) + ifos.set_array_backend(self.xp) self.priors = bilby.gw.prior.BBHPriorDict() self.priors.pop("mass_1") @@ -342,6 +399,7 @@ def setUp(self): waveform_approximant="IMRPhenomPv2", ), ) + non_roq_wfg = BackendWaveformGenerator(non_roq_wfg, self.xp) ifos.inject_signal( parameters=self.test_parameters, waveform_generator=non_roq_wfg @@ -349,20 +407,6 @@ def setUp(self): self.ifos = ifos - roq_wfg = bilby.gw.waveform_generator.WaveformGenerator( - duration=self.duration, - sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, - waveform_arguments=dict( - frequency_nodes_linear=fnodes_linear, - frequency_nodes_quadratic=fnodes_quadratic, - reference_frequency=20.0, - waveform_approximant="IMRPhenomPv2", - ), - ) - - self.roq_wfg = roq_wfg - self.non_roq = bilby.gw.likelihood.GravitationalWaveTransient( interferometers=ifos, waveform_generator=non_roq_wfg ) @@ -374,38 +418,71 @@ def setUp(self): priors=self.priors.copy(), ) - self.roq = bilby.gw.likelihood.ROQGravitationalWaveTransient( - interferometers=ifos, - waveform_generator=roq_wfg, - linear_matrix=linear_matrix_file, - quadratic_matrix=quadratic_matrix_file, - priors=self.priors, - ) - - self.roq_phase = bilby.gw.likelihood.ROQGravitationalWaveTransient( - interferometers=ifos, - waveform_generator=roq_wfg, - linear_matrix=linear_matrix_file, - quadratic_matrix=quadratic_matrix_file, - phase_marginalization=True, - priors=self.priors.copy(), - ) - def tearDown(self): del ( - self.roq, self.non_roq, self.non_roq_phase, - self.roq_phase, self.ifos, self.priors, ) + @property + def linear_matrix_file(self): + return f"{self.roq_dir}/B_linear.npy" + + @property + def quadratic_matrix_file(self): + return f"{self.roq_dir}/B_quadratic.npy" + + @property + def params_file(self): + return f"{self.roq_dir}/params.dat" + + @cached_property + def roq_wfg(self): + fnodes_linear_file = 
f"{self.roq_dir}/fnodes_linear.npy" + fnodes_quadratic_file = f"{self.roq_dir}/fnodes_quadratic.npy" + fnodes_linear = np.load(fnodes_linear_file).T + fnodes_quadratic = np.load(fnodes_quadratic_file).T + wfg = bilby.gw.waveform_generator.WaveformGenerator( + duration=self.duration, + sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, + waveform_arguments=dict( + frequency_nodes_linear=fnodes_linear, + frequency_nodes_quadratic=fnodes_quadratic, + reference_frequency=20.0, + waveform_approximant="IMRPhenomPv2", + ), + ) + return BackendWaveformGenerator(wfg, self.xp) + + @cached_property + def roq(self): + return bilby.gw.likelihood.ROQGravitationalWaveTransient( + interferometers=self.ifos, + waveform_generator=self.roq_wfg, + linear_matrix=self.linear_matrix_file, + quadratic_matrix=self.quadratic_matrix_file, + priors=self.priors, + ) + + @cached_property + def roq_phase(self): + return bilby.gw.likelihood.ROQGravitationalWaveTransient( + interferometers=self.ifos, + waveform_generator=self.roq_wfg, + linear_matrix=self.linear_matrix_file, + quadratic_matrix=self.quadratic_matrix_file, + phase_marginalization=True, + priors=self.priors.copy(), + ) + def test_matches_non_roq(self): + roq_llr = self.roq.log_likelihood_ratio(self.test_parameters) self.assertLess( abs( - self.non_roq.log_likelihood_ratio(self.test_parameters) - - self.roq.log_likelihood_ratio(self.test_parameters) + self.non_roq.log_likelihood_ratio(self.test_parameters) - roq_llr ) / self.non_roq.log_likelihood_ratio(self.test_parameters), 1e-3, ) @@ -424,10 +501,12 @@ def test_create_roq_weights_with_params(self): quadratic_matrix=self.quadratic_matrix_file, priors=self.priors, ) + roq_llr = roq.log_likelihood_ratio(self.test_parameters) self.assertEqual( - roq.log_likelihood_ratio(self.test_parameters), + roq_llr, self.roq.log_likelihood_ratio(self.test_parameters) ) + self.assertEqual(aac.get_namespace(roq_llr), self.xp) def 
test_create_roq_weights_frequency_mismatch_works_with_params(self): @@ -537,33 +616,18 @@ def test_create_roq_weights_fails_due_to_duration(self): @pytest.mark.requires_roqs -class TestRescaledROQLikelihood(unittest.TestCase): +class TestRescaledROQLikelihood(unittest.TestCase, ROQBasisMixin): def test_rescaling(self): + linear_matrix_file = f"{self.roq_dir}/B_linear.npy" + quadratic_matrix_file = f"{self.roq_dir}/B_quadratic.npy" - # Possible locations for the ROQ: in the docker image, local, or on CIT - trial_roq_paths = [ - "/roq_basis", - os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), - "/home/cbc/ROQ_data/IMRPhenomPv2/4s", - ] - roq_dir = None - for path in trial_roq_paths: - if os.path.isdir(path): - roq_dir = path - break - if roq_dir is None: - raise Exception("Unable to load ROQ basis: cannot proceed with tests") - - linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - - fnodes_linear_file = "{}/fnodes_linear.npy".format(roq_dir) + fnodes_linear_file = f"{self.roq_dir}/fnodes_linear.npy" fnodes_linear = np.load(fnodes_linear_file).T - fnodes_quadratic_file = "{}/fnodes_quadratic.npy".format(roq_dir) + fnodes_quadratic_file = f"{self.roq_dir}/fnodes_quadratic.npy" fnodes_quadratic = np.load(fnodes_quadratic_file).T - self.linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - self.quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - self.params_file = "{}/params.dat".format(roq_dir) + self.linear_matrix_file = f"{self.roq_dir}/B_linear.npy" + self.quadratic_matrix_file = f"{self.roq_dir}/B_quadratic.npy" + self.params_file = f"{self.roq_dir}/params.dat" scale_factor = 0.5 params = np.genfromtxt(self.params_file, names=True) @@ -611,7 +675,9 @@ def test_rescaling(self): @pytest.mark.requires_roqs -class TestROQLikelihoodHDF5(unittest.TestCase): +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") +class TestROQLikelihoodHDF5(unittest.TestCase, 
ROQBasisMixin): """ Test ROQ likelihood constructed from .hdf5 basis @@ -619,14 +685,13 @@ class TestROQLikelihoodHDF5(unittest.TestCase): respectively, and 2 quadratic bases constructed over 8Msun= self.priors["chirp_mass"].minimum) * @@ -843,13 +910,14 @@ def assertLess_likelihood_errors( self.priors["chirp_mass"].maximum = mc_max interferometers = bilby.gw.detector.InterferometerList(["H1", "L1"]) + interferometers.set_array_backend(self.xp) for ifo in interferometers: if minimum_frequency is None: ifo.minimum_frequency = self.minimum_frequency else: - ifo.minimum_frequency = minimum_frequency + ifo.minimum_frequency = self.xp.asarray(minimum_frequency) if maximum_frequency is not None: - ifo.maximum_frequency = maximum_frequency + ifo.maximum_frequency = self.xp.asarray(maximum_frequency) interferometers.set_strain_data_from_zero_noise( sampling_frequency=self.sampling_frequency, duration=self.duration, @@ -884,6 +952,7 @@ def assertLess_likelihood_errors( waveform_approximant=self.waveform_approximant ) ) + waveform_generator = BackendWaveformGenerator(waveform_generator, self.xp) interferometers.inject_signal(waveform_generator=waveform_generator, parameters=self.injection_parameters) likelihood = bilby.gw.GravitationalWaveTransient( @@ -901,12 +970,13 @@ def assertLess_likelihood_errors( waveform_approximant=self.waveform_approximant ) ) + search_waveform_generator = BackendWaveformGenerator(search_waveform_generator, self.xp) likelihood_roq = bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=interferometers, priors=self.priors, waveform_generator=search_waveform_generator, - linear_matrix=basis_linear, - quadratic_matrix=basis_quadratic, + linear_matrix=f"{self.roq_dir}/{basis_linear}", + quadratic_matrix=f"{self.roq_dir}/{basis_quadratic}", roq_scale_factor=roq_scale_factor ) for mc in np.linspace(self.priors["chirp_mass"].minimum, self.priors["chirp_mass"].maximum, 11): @@ -915,10 +985,11 @@ def assertLess_likelihood_errors( llr = 
likelihood.log_likelihood_ratio(parameters) llr_roq = likelihood_roq.log_likelihood_ratio(parameters) self.assertLess(np.abs(llr - llr_roq), max_llr_error) + self.assertEqual(aac.get_namespace(llr_roq), self.xp) @pytest.mark.requires_roqs -class TestCreateROQLikelihood(unittest.TestCase): +class TestCreateROQLikelihood(unittest.TestCase, ROQBasisMixin): """ Test if ROQ likelihood is constructed without any errors from .hdf5 or .npy basis @@ -926,9 +997,8 @@ class TestCreateROQLikelihood(unittest.TestCase): respectively, and 2 quadratic bases constructed over 8Msun 0.001) + self.assertEqual(aac.get_namespace(samples), self.xp) if __name__ == "__main__": diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index cf78849c7..2fc700993 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -3,6 +3,7 @@ from shutil import rmtree from importlib.metadata import version +import array_api_compat as aac import numpy as np import lal import lalsimulation as lalsim @@ -15,6 +16,8 @@ from bilby.gw import utils as gwutils +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGWUtils(unittest.TestCase): def setUp(self): self.outdir = "outdir" @@ -27,29 +30,36 @@ def tearDown(self): pass def test_asd_from_freq_series(self): - freq_data = np.array([1, 2, 3]) + freq_data = self.xp.asarray([1, 2, 3]) df = 0.1 asd = gwutils.asd_from_freq_series(freq_data, df) + self.assertEqual(aac.get_namespace(asd), self.xp) + asd = np.asarray(asd) + freq_data = np.asarray(freq_data) self.assertTrue(np.all(asd == freq_data * 2 * df ** 0.5)) def test_psd_from_freq_series(self): - freq_data = np.array([1, 2, 3]) + freq_data = self.xp.asarray([1, 2, 3]) df = 0.1 psd = gwutils.psd_from_freq_series(freq_data, df) + self.assertEqual(aac.get_namespace(psd), self.xp) + psd = np.asarray(psd) + freq_data = np.asarray(freq_data) self.assertTrue(np.all(psd == (freq_data * 2 * df ** 0.5) ** 2)) def test_inner_product(self): - aa = np.array([1, 2, 3]) - bb = np.array([5, 6, 
7]) - frequency = np.array([0.2, 0.4, 0.6]) + aa = self.xp.asarray([1, 2, 3]) + bb = self.xp.asarray([5, 6, 7]) + frequency = self.xp.asarray([0.2, 0.4, 0.6]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() ip = gwutils.inner_product(aa, bb, frequency, PSD) self.assertEqual(ip, 0) + self.assertEqual(aac.get_namespace(ip), self.xp) def test_noise_weighted_inner_product(self): - aa = np.array([1e-23, 2e-23, 3e-23]) - bb = np.array([5e-23, 6e-23, 7e-23]) - frequency = np.array([100, 101, 102]) + aa = self.xp.asarray([1e-23, 2e-23, 3e-23]) + bb = self.xp.asarray([5e-23, 6e-23, 7e-23]) + frequency = self.xp.asarray([100, 101, 102]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() psd = PSD.power_spectral_density_interpolated(frequency) duration = 4 @@ -60,11 +70,12 @@ def test_noise_weighted_inner_product(self): gwutils.optimal_snr_squared(aa, psd, duration), gwutils.noise_weighted_inner_product(aa, aa, psd, duration), ) + self.assertEqual(aac.get_namespace(nwip), self.xp) def test_matched_filter_snr(self): - signal = np.array([1e-23, 2e-23, 3e-23]) - frequency_domain_strain = np.array([5e-23, 6e-23, 7e-23]) - frequency = np.array([100, 101, 102]) + signal = self.xp.asarray([1e-23, 2e-23, 3e-23]) + frequency_domain_strain = self.xp.asarray([5e-23, 6e-23, 7e-23]) + frequency = self.xp.asarray([100, 101, 102]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() psd = PSD.power_spectral_density_interpolated(frequency) duration = 4 @@ -73,6 +84,27 @@ def test_matched_filter_snr(self): signal, frequency_domain_strain, psd, duration ) self.assertEqual(mfsnr, 25.510869054168282) + self.assertEqual(aac.get_namespace(mfsnr), self.xp) + + def test_overlap(self): + signal = self.xp.linspace(1e-23, 21e-23, 21) + frequency_domain_strain = self.xp.linspace(5e-23, 25e-23, 21) + frequency = self.xp.linspace(100, 120, 21) + PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() + psd = PSD.power_spectral_density_interpolated(frequency) + duration = 4 + 
overlap = gwutils.overlap( + signal, + frequency_domain_strain, + psd, + delta_frequency=1 / duration, + lower_cut_off=3, + upper_cut_off=18, + norm_a=gwutils.optimal_snr_squared(signal, psd, duration), + norm_b=gwutils.optimal_snr_squared(frequency_domain_strain, psd, duration), + ) + self.assertEqual(aac.get_namespace(overlap), self.xp) + self.assertAlmostEqual(float(overlap), 2.76914407e-05) @pytest.mark.skip(reason="GWOSC unstable: avoiding this test") def test_get_event_time(self): @@ -264,6 +296,8 @@ def test_safe_cast_mode_to_int(self): gwutils.safe_cast_mode_to_int(None) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestSkyFrameConversion(unittest.TestCase): def setUp(self) -> None: @@ -281,23 +315,39 @@ def tearDown(self) -> None: del self.ifos del self.samples + def test_conversion_single(self) -> None: + sample = self.priors.sample() + zenith = self.xp.asarray(sample["zenith"]) + azimuth = self.xp.asarray(sample["azimuth"]) + time = self.xp.asarray(sample["time"]) + self.ifos.set_array_backend(self.xp) + ra, dec = bilby.gw.utils.zenith_azimuth_to_ra_dec( + zenith, azimuth, time, self.ifos + ) + self.assertEqual(aac.get_namespace(ra), self.xp) + self.assertEqual(aac.get_namespace(dec), self.xp) + def test_conversion_gives_correct_prior(self) -> None: - zeniths = self.samples["zenith"] - azimuths = self.samples["azimuth"] - times = self.samples["time"] - args = zip(*[ - (zenith, azimuth, time, self.ifos) - for zenith, azimuth, time in zip(zeniths, azimuths, times) - ]) - ras, decs = zip(*map(bilby.gw.utils.zenith_azimuth_to_ra_dec, *args)) + zeniths = self.xp.asarray(self.samples["zenith"]) + azimuths = self.xp.asarray(self.samples["azimuth"]) + times = self.xp.asarray(self.samples["time"]) + self.ifos.set_array_backend(self.xp) + ras, decs = bilby.gw.utils.zenith_azimuth_to_ra_dec( + zeniths, azimuths, times, self.ifos + ) + self.assertEqual(aac.get_namespace(ras), self.xp) + self.assertEqual(aac.get_namespace(decs), self.xp) + 
ras = np.asarray(ras) + decs = np.asarray(decs) self.assertGreaterEqual(ks_2samp(self.samples["ra"], ras).pvalue, 0.01) self.assertGreaterEqual(ks_2samp(self.samples["dec"], decs).pvalue, 0.01) -def test_ln_i0_mathces_scipy(): +@pytest.mark.array_backend +def test_ln_i0_mathces_scipy(xp): from scipy.special import i0 - values = np.linspace(-10, 10, 101) - assert max(abs(gwutils.ln_i0(values) - np.log(i0(values)))) < 1e-10 + values = xp.linspace(-10, 10, 101) + assert max(abs(gwutils.ln_i0(values) - xp.log(i0(values)))) < 1e-10 if __name__ == "__main__": diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py index f63b40537..70e48aa83 100644 --- a/test/gw/waveform_generator_test.py +++ b/test/gw/waveform_generator_test.py @@ -1,9 +1,12 @@ import unittest from unittest import mock +import array_api_compat as aac import bilby import lalsimulation import numpy as np +import pytest +from bilby.compat.utils import xp_wrap def dummy_func_array_return_value( @@ -36,16 +39,21 @@ def dummy_func_dict_return_value( return ht +@xp_wrap def dummy_func_array_return_value_2( - array, amplitude, mu, sigma, ra, dec, geocent_time, psi + array, amplitude, mu, sigma, ra, dec, geocent_time, psi, *, xp=None ): - return dict(plus=np.array(array), cross=np.array(array)) + return dict(plus=xp.asarray(array), cross=xp.asarray(array)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - 1, 4096, frequency_domain_source_model=dummy_func_dict_return_value + self.xp.asarray(1.0), + self.xp.asarray(4096.0), + frequency_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( amplitude=1e-21, @@ -118,9 +126,11 @@ def conversion_func(): def test_duration(self): self.assertEqual(self.waveform_generator.duration, 1) + 
self.assertEqual(aac.get_namespace(self.waveform_generator.duration), self.xp) def test_sampling_frequency(self): self.assertEqual(self.waveform_generator.sampling_frequency, 4096) + self.assertEqual(aac.get_namespace(self.waveform_generator.sampling_frequency), self.xp) def test_source_model(self): self.assertEqual( @@ -129,10 +139,10 @@ def test_source_model(self): ) def test_frequency_array_type(self): - self.assertIsInstance(self.waveform_generator.frequency_array, np.ndarray) + self.assertEqual(aac.array_namespace(self.waveform_generator.frequency_array), self.xp) def test_time_array_type(self): - self.assertIsInstance(self.waveform_generator.time_array, np.ndarray) + self.assertEqual(aac.array_namespace(self.waveform_generator.time_array), self.xp) def test_source_model_parameters(self): self.waveform_generator.parameters = self.simulation_parameters.copy() @@ -301,11 +311,13 @@ def conversion_func(): self.assertEqual(conversion_func, self.waveform_generator.parameter_conversion) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestFrequencyDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=1, - sampling_frequency=4096, + duration=self.xp.asarray(1.0), + sampling_frequency=self.xp.asarray(4096.0), frequency_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( @@ -347,6 +359,8 @@ def test_frequency_domain_source_model_call(self): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_time_domain_source_model_call_with_ndarray(self): self.waveform_generator.frequency_domain_source_model = None @@ -364,6 +378,7 @@ def side_effect(value, value2): parameters=self.simulation_parameters ) 
self.assertTrue(np.array_equal(expected, actual)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_time_domain_source_model_call_with_dict(self): self.waveform_generator.frequency_domain_source_model = None @@ -382,6 +397,8 @@ def side_effect(value, value2): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None @@ -491,8 +508,8 @@ def test_frequency_domain_caching_changing_model(self): def test_time_domain_caching_changing_model(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=1, - sampling_frequency=4096, + duration=self.xp.asarray(1.0), + sampling_frequency=self.xp.asarray(4096.0), time_domain_source_model=dummy_func_dict_return_value, ) original_waveform = self.waveform_generator.frequency_domain_strain( @@ -507,12 +524,18 @@ def test_time_domain_caching_changing_model(self): self.assertFalse( np.array_equal(original_waveform["plus"], new_waveform["plus"]) ) + self.assertEqual(aac.get_namespace(new_waveform["plus"]), self.xp) + self.assertEqual(aac.get_namespace(new_waveform["cross"]), self.xp) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestTimeDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - 1, 4096, time_domain_source_model=dummy_func_dict_return_value + self.xp.asarray(1.0), + self.xp.asarray(4096.0), + time_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( amplitude=1e-21, @@ -553,6 +576,8 @@ def test_time_domain_source_model_call(self): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], 
actual["cross"])) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_frequency_domain_source_model_call_with_ndarray(self): self.waveform_generator.time_domain_source_model = None @@ -572,6 +597,7 @@ def side_effect(value, value2): parameters=self.simulation_parameters ) self.assertTrue(np.array_equal(expected, actual)) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_frequency_domain_source_model_call_with_dict(self): self.waveform_generator.time_domain_source_model = None @@ -592,6 +618,8 @@ def side_effect(value, value2): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None