From 5faa4329106163c1737bc6fbd034b1da82efc3a3 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 25 Oct 2024 16:10:32 +0000 Subject: [PATCH 001/110] FEAT: enable backend switching for base gravitational-wave transient likelihood --- .github/workflows/basic-install.yml | 13 +- bilby/core/utils/series.py | 13 +- bilby/gw/detector/geometry.py | 20 +- bilby/gw/detector/interferometer.py | 29 ++- bilby/gw/detector/networks.py | 4 + bilby/gw/jaxstuff.py | 113 ++++++++++++ bilby/gw/likelihood/base.py | 12 +- bilby/gw/utils.py | 25 +-- bilby/gw/waveform_generator.py | 20 +- .../injection_examples/jax_fast_tutorial.py | 173 ++++++++++++++++++ 10 files changed, 374 insertions(+), 48 deletions(-) create mode 100644 bilby/gw/jaxstuff.py create mode 100644 examples/gw_examples/injection_examples/jax_fast_tutorial.py diff --git a/.github/workflows/basic-install.yml b/.github/workflows/basic-install.yml index a71bfc18e..976fe9255 100644 --- a/.github/workflows/basic-install.yml +++ b/.github/workflows/basic-install.yml @@ -20,8 +20,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - # disable windows build test as bilby_cython is currently broken there - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest, windows-latest] python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 @@ -52,8 +51,8 @@ jobs: python -c "import bilby.hyper" python -c "import cli_bilby" python test/import_test.py - # - if: ${{ matrix.os != "windows-latest" }} - # run: | - # for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do - # ${script} --help; - # done + - if: ${{ matrix.os != "windows-latest" }} + run: | + for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do + ${script} --help; + done diff --git a/bilby/core/utils/series.py b/bilby/core/utils/series.py index 63daebd6e..c3d71e3e2 100644 --- a/bilby/core/utils/series.py +++ b/bilby/core/utils/series.py @@ -1,4 +1,5 @@ import numpy as np +from bilback.utils 
import array_module _TOL = 14 @@ -97,9 +98,10 @@ def create_time_series(sampling_frequency, duration, starting_time=0.): float: An equidistant time series given the parameters """ + xp = array_module(sampling_frequency) _check_legal_sampling_frequency_and_duration(sampling_frequency, duration) number_of_samples = int(duration * sampling_frequency) - return np.linspace(start=starting_time, + return xp.linspace(start=starting_time, stop=duration + starting_time - 1 / sampling_frequency, num=number_of_samples) @@ -117,11 +119,12 @@ def create_frequency_series(sampling_frequency, duration): array_like: frequency series """ + xp = array_module(sampling_frequency) _check_legal_sampling_frequency_and_duration(sampling_frequency, duration) - number_of_samples = int(np.round(duration * sampling_frequency)) - number_of_frequencies = int(np.round(number_of_samples / 2) + 1) + number_of_samples = int(xp.round(duration * sampling_frequency)) + number_of_frequencies = int(xp.round(number_of_samples / 2) + 1) - return np.linspace(start=0, + return xp.linspace(start=0, stop=sampling_frequency / 2, num=number_of_frequencies) @@ -139,7 +142,7 @@ def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration): """ num = sampling_frequency * duration - if np.abs(num - np.round(num)) > 10**(-_TOL): + if abs(num % 1) > 10**(-_TOL): raise IllegalDurationAndSamplingFrequencyException( '\nYour sampling frequency and duration must multiply to a number' 'up to (tol = {}) decimals close to an integer number. ' diff --git a/bilby/gw/detector/geometry.py b/bilby/gw/detector/geometry.py index d7e1433de..ea2a509ab 100644 --- a/bilby/gw/detector/geometry.py +++ b/bilby/gw/detector/geometry.py @@ -1,5 +1,5 @@ import numpy as np -from bilby_cython.geometry import calculate_arm, detector_tensor +from bilback.geometry import calculate_arm, detector_tensor from .. 
import utils as gwutils @@ -264,7 +264,7 @@ def detector_tensor(self): if not self._x_updated or not self._y_updated: _, _ = self.x, self.y # noqa if not self._detector_tensor_updated: - self._detector_tensor = detector_tensor(x=self.x, y=self.y) + self._detector_tensor = detector_tensor(self.x, self.y) self._detector_tensor_updated = True return self._detector_tensor @@ -290,17 +290,17 @@ def unit_vector_along_arm(self, arm): """ if arm == 'x': return calculate_arm( - arm_tilt=self._xarm_tilt, - arm_azimuth=self._xarm_azimuth, - longitude=self._longitude, - latitude=self._latitude + self._xarm_tilt, + self._xarm_azimuth, + self._longitude, + self._latitude ) elif arm == 'y': return calculate_arm( - arm_tilt=self._yarm_tilt, - arm_azimuth=self._yarm_azimuth, - longitude=self._longitude, - latitude=self._latitude + self._yarm_tilt, + self._yarm_azimuth, + self._longitude, + self._latitude ) else: raise ValueError("Arm must either be 'x' or 'y'.") diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 2fca163e0..29f52dcd4 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -1,7 +1,7 @@ import os import numpy as np -from bilby_cython.geometry import ( +from bilback.geometry import ( get_polarization_tensor, three_by_three_matrix_contraction, time_delay_from_geocenter, @@ -305,7 +305,8 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= array_like: A 3x3 array representation of the detector response (signal observed in the interferometer) """ if frequencies is None: - frequencies = self.frequency_array[self.frequency_mask] + # frequencies = self.frequency_array[self.frequency_mask] + frequencies = self.frequency_array mask = self.frequency_mask else: mask = np.ones(len(frequencies), dtype=bool) @@ -318,8 +319,8 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= parameters['geocent_time'], parameters['psi'], mode) - 
signal[mode] = waveform_polarizations[mode] * det_response - signal_ifo = sum(signal.values()) * mask + signal[mode] = waveform_polarizations[mode] * mask * det_response + signal_ifo = sum(signal.values()) time_shift = self.time_delay_from_geocenter( parameters['ra'], parameters['dec'], parameters['geocent_time']) @@ -329,9 +330,12 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= dt_geocent = parameters['geocent_time'] - self.strain_data.start_time dt = dt_geocent + time_shift - signal_ifo[mask] = signal_ifo[mask] * np.exp(-1j * 2 * np.pi * dt * frequencies) + from bilback.utils import array_module + xp = array_module(signal_ifo) - signal_ifo[mask] *= self.calibration_model.get_calibration_factor( + signal_ifo = signal_ifo * xp.exp(-1j * 2 * np.pi * dt * frequencies) + + signal_ifo *= self.calibration_model.get_calibration_factor( frequencies, prefix='recalib_{}_'.format(self.name), **parameters ) @@ -923,3 +927,16 @@ def from_pickle(cls, filename=None): if res.__class__ != cls: raise TypeError('The loaded object is not an Interferometer') return res + + def set_array_backend(self, xp): + for attr in [ + "length", + "latitude", + "longitude", + "elevation", + "xarm_azimuth", + "yarm_azimuth", + "xarm_tilt", + "yarm_tilt", + ]: + setattr(self, attr, xp.array(getattr(self, attr))) diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index 4efd3d8db..b686f2fd5 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -331,6 +331,10 @@ def from_pickle(cls, filename=None): ) from_pickle.__doc__ = _load_docstring.format(format="pickle") + def set_array_backend(self, xp): + for ifo in self: + ifo.set_array_backend(xp) + class TriangularInterferometer(InterferometerList): def __init__( diff --git a/bilby/gw/jaxstuff.py b/bilby/gw/jaxstuff.py new file mode 100644 index 000000000..f02df34a3 --- /dev/null +++ b/bilby/gw/jaxstuff.py @@ -0,0 +1,113 @@ +""" +Generic dumping ground for jax-specific 
functions that we need. +This should find a home somewhere down the line, but gives an +idea of how much pain is being added. +""" + +from functools import partial + +import numpy as np +from bilby.core.likelihood import Likelihood + +import jax +import jax.numpy as jnp +from plum import dispatch +from jax.scipy.special import i0e +from ripple.waveforms import IMRPhenomPv2 + +def bilby_to_ripple_spins( + theta_jn, + phi_jl, + tilt_1, + tilt_2, + phi_12, + a_1, + a_2, +): + iota = theta_jn + spin_1x = a_1 * jnp.sin(tilt_1) * jnp.cos(phi_jl) + spin_1y = a_1 * jnp.sin(tilt_1) * jnp.sin(phi_jl) + spin_1z = a_1 * jnp.cos(tilt_1) + spin_2x = a_2 * jnp.sin(tilt_2) * jnp.cos(phi_jl + phi_12) + spin_2y = a_2 * jnp.sin(tilt_2) * jnp.sin(phi_jl + phi_12) + spin_2z = a_2 * jnp.cos(tilt_2) + return iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z + + +def ripple_bbh(frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase, + a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, **kwargs): + iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_ripple_spins( + theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2 + ) + theta = jnp.array([ + mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, + luminosity_distance, 0.0, phase, iota + ]) + hp, hc = jax.jit(IMRPhenomPv2.gen_IMRPhenomPv2)(frequency, theta, jax.numpy.array(20.0)) + return dict(plus=hp, cross=hc) + + +def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True): + """ + A wrapper to allow a :code:`Bilby` likelihood to be used with :code:`jax`. + + Parameters + ========== + likelihood: bilby.core.likelihood.Likelihood + The likelihood to evaluate. + parameters: dict + The parameters to evaluate the likelihood at. + use_ratio: bool, optional + Whether to evaluate the likelihood ratio or the full likelihood. + Default is :code:`True`. 
+ """ + likelihood.parameters.update(parameters) + if use_ratio: + return likelihood.log_likelihood_ratio() + else: + return likelihood.log_likelihood() + + +class JittedLikelihood(Likelihood): + """ + A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`. + + .. note:: + + This is currently hardcoded to return the log likelihood ratio, regardless of + the input. + + Parameters + ========== + likelihood: bilby.core.likelihood.Likelihood + The likelihood to wrap. + likelihood_func: callable, optional + The function to use to evaluate the likelihood. Default is + :code:`generic_bilby_likelihood_function`. This function should take the + likelihood and parameters as arguments along with additional keyword arguments. + kwargs: dict, optional + Additional keyword arguments to pass to the likelihood function. + """ + + def __init__( + self, likelihood, likelihood_func=generic_bilby_likelihood_function, kwargs=None + ): + if kwargs is None: + kwargs = dict() + self.kwargs = kwargs + self._likelihood = likelihood + self.likelihood_func = jax.jit(partial(likelihood_func, likelihood)) + super().__init__(dict()) + + def __getattr__(self, name): + return getattr(self._likelihood, name) + + def log_likelihood_ratio(self): + return float( + np.nan_to_num(self.likelihood_func(self.parameters, **self.kwargs)) + ) + + +@dispatch +def ln_i0(value: jax.Array): + return jnp.log(i0e(value)) + jnp.abs(value) \ No newline at end of file diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index c66e05a34..865944ed0 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -107,9 +107,13 @@ class GravitationalWaveTransient(Likelihood): @attr.s(slots=True, weakref_slot=False) class _CalculatedSNRs: - d_inner_h = attr.ib(default=0j, converter=complex) - optimal_snr_squared = attr.ib(default=0, converter=float) - complex_matched_filter_snr = attr.ib(default=0j, converter=complex) + # the complex converted breaks JAX 
compilation + # d_inner_h = attr.ib(default=0j, converter=complex) + # optimal_snr_squared = attr.ib(default=0, converter=float) + # complex_matched_filter_snr = attr.ib(default=0j, converter=complex) + d_inner_h = attr.ib(default=0j) + optimal_snr_squared = attr.ib(default=0) + complex_matched_filter_snr = attr.ib(default=0j) d_inner_h_array = attr.ib(default=None) optimal_snr_squared_array = attr.ib(default=None) @@ -436,7 +440,7 @@ def log_likelihood_ratio(self, parameters=None): if self.time_marginalization and self.jitter_time: parameters['geocent_time'] -= parameters['time_jitter'] - return float(log_l.real) + return log_l.real def compute_log_likelihood_from_snrs(self, total_snrs, parameters=None): parameters = _fallback_to_parameters(self, parameters) diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index f1f4c0291..8ee65e7fe 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -5,10 +5,12 @@ import numpy as np from scipy.interpolate import interp1d from scipy.special import i0e -from bilby_cython.geometry import ( +from bilback.geometry import ( zenith_azimuth_to_theta_phi as _zenith_azimuth_to_theta_phi, ) -from bilby_cython.time import greenwich_mean_sidereal_time +from bilback.time import greenwich_mean_sidereal_time +from bilback.utils import array_module +from plum import dispatch from ..core.utils import (logger, run_commandline, check_directory_exists_and_if_not_mkdir, @@ -76,14 +78,15 @@ def get_vertex_position_geocentric(latitude, longitude, elevation): array_like: A 3D representation of the geocentric vertex position """ + xp = array_module(latitude) semi_major_axis = 6378137 # for ellipsoid model of Earth, in m semi_minor_axis = 6356752.314 # in m - radius = semi_major_axis**2 * (semi_major_axis**2 * np.cos(latitude)**2 + - semi_minor_axis**2 * np.sin(latitude)**2)**(-0.5) - x_comp = (radius + elevation) * np.cos(latitude) * np.cos(longitude) - y_comp = (radius + elevation) * np.cos(latitude) * np.sin(longitude) - z_comp = 
((semi_minor_axis / semi_major_axis)**2 * radius + elevation) * np.sin(latitude) - return np.array([x_comp, y_comp, z_comp]) + radius = semi_major_axis**2 * (semi_major_axis**2 * xp.cos(latitude)**2 + + semi_minor_axis**2 * xp.sin(latitude)**2)**(-0.5) + x_comp = (radius + elevation) * xp.cos(latitude) * xp.cos(longitude) + y_comp = (radius + elevation) * xp.cos(latitude) * xp.sin(longitude) + z_comp = ((semi_minor_axis / semi_major_axis)**2 * radius + elevation) * xp.sin(latitude) + return xp.array([x_comp, y_comp, z_comp]) def inner_product(aa, bb, frequency, PSD): @@ -132,9 +135,8 @@ def noise_weighted_inner_product(aa, bb, power_spectral_density, duration): ======= Noise-weighted inner product. """ - - integrand = np.conj(aa) * bb / power_spectral_density - return 4 / duration * np.sum(integrand) + integrand = aa.conjugate() * bb / power_spectral_density + return 4 / duration * integrand.sum() def matched_filter_snr(signal, frequency_domain_strain, power_spectral_density, duration): @@ -1023,6 +1025,7 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label= plt.xlim(freq_points.min() - .5, freq_points.max() + 50) +@dispatch def ln_i0(value): """ A numerically stable method to evaluate ln(I_0) a modified Bessel function diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py index 5b78a6b6a..3983833a6 100644 --- a/bilby/gw/waveform_generator.py +++ b/bilby/gw/waveform_generator.py @@ -24,7 +24,8 @@ class WaveformGenerator(object): def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None, time_domain_source_model=None, parameters=None, parameter_conversion=None, - waveform_arguments=None): + waveform_arguments=None, use_cache=True, + ): """ The base waveform generator class. 
@@ -57,6 +58,10 @@ def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequen Note: the arguments of frequency_domain_source_model (except the first, which is the frequencies at which to compute the strain) will be added to the WaveformGenerator object and initialised to `None`. + use_cache: bool + Whether to attempt caching the waveform between subsequent calls. + This is :code:`True` by default but must be disabled for JIT compilation + with :code:`JAX`. """ self._times_and_frequencies = CoupledTimeAndFrequencySeries(duration=duration, @@ -76,6 +81,7 @@ def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequen if isinstance(parameters, dict): self.parameters = parameters self._cache = dict(parameters=None, waveform=None, model=None) + self.use_cache = use_cache logger.info(f"Waveform generator instantiated: {self}") def __repr__(self): @@ -160,10 +166,14 @@ def time_domain_strain(self, parameters=None): def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model, transformed_model_data_points, parameters): - if parameters is None: - parameters = self.parameters - if parameters == self._cache['parameters'] and self._cache['model'] == model and \ - self._cache['transformed_model'] == transformed_model: + if parameters is not None: + self.parameters = parameters + if ( + self.use_cache + and self.parameters == self._cache['parameters'] + and self._cache['model'] == model + and self._cache['transformed_model'] == transformed_model + ): return self._cache['waveform'] else: self._cache['parameters'] = parameters.copy() diff --git a/examples/gw_examples/injection_examples/jax_fast_tutorial.py b/examples/gw_examples/injection_examples/jax_fast_tutorial.py new file mode 100644 index 000000000..b4acfd333 --- /dev/null +++ b/examples/gw_examples/injection_examples/jax_fast_tutorial.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +""" +Tutorial to demonstrate running parameter estimation on a 
reduced parameter +space for an injected signal. + +This example estimates the masses using a uniform prior in both component masses +and distance using a uniform in comoving volume prior on luminosity distance +between luminosity distances of 100Mpc and 5Gpc, the cosmology is Planck15. + +We optionally use ripple waveforms and a JIT-compiled likelihood. +""" + +import bilby +import bilby.gw.jaxstuff +import jax +jax.config.update("jax_enable_x64", True) + +USE_JAX = True + +# Set the duration and sampling frequency of the data segment that we're +# going to inject the signal into +duration = 4.0 +sampling_frequency = 2048.0 +minimum_frequency = 20.0 +if USE_JAX: + duration = jax.numpy.array(duration) + sampling_frequency = jax.numpy.array(sampling_frequency) + minimum_frequency = jax.numpy.array(minimum_frequency) + +# Specify the output directory and the name of the simulation. +outdir = "outdir" +label = "fast_tutorial" +bilby.core.utils.setup_logger(outdir=outdir, label=label) + +# Set up a random seed for result reproducibility. This is optional! +bilby.core.utils.random.seed(88170235) + +# We are going to inject a binary black hole waveform. We first establish a +# dictionary of parameters that includes all of the different waveform +# parameters, including masses of the two black holes (mass_1, mass_2), +# spins of both black holes (a, tilt, phi), etc. 
+injection_parameters = dict( + mass_1=36.0, + mass_2=29.0, + a_1=0.4, + a_2=0.3, + tilt_1=0.5, + tilt_2=1.0, + phi_12=1.7, + phi_jl=0.3, + luminosity_distance=2000.0, + theta_jn=0.4, + psi=2.659, + phase=1.3, + geocent_time=1126259642.413, + ra=1.375, + dec=-1.2108, +) + +# Fixed arguments passed into the source model +waveform_arguments = dict( + waveform_approximant="IMRPhenomPv2", + reference_frequency=50.0, + minimum_frequency=minimum_frequency, +) + +if USE_JAX: + fdsm = bilby.gw.jaxstuff.ripple_bbh +else: + fdsm = bilby.gw.source.lal_binary_black_hole + +# Create the waveform_generator using a LAL BinaryBlackHole source function +waveform_generator = bilby.gw.WaveformGenerator( + duration=duration, + sampling_frequency=sampling_frequency, + frequency_domain_source_model=fdsm, + parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters, + waveform_arguments=waveform_arguments, + use_cache=False, +) + +# Set up interferometers. In this case we'll use two interferometers +# (LIGO-Hanford (H1), LIGO-Livingston (L1). These default to their design +# sensitivity +ifos = bilby.gw.detector.InterferometerList(["H1", "L1"]) +ifos.set_strain_data_from_power_spectral_densities( + sampling_frequency=sampling_frequency, + duration=duration, + start_time=injection_parameters["geocent_time"] - 2, +) +ifos.inject_signal( + waveform_generator=waveform_generator, parameters=injection_parameters +) +if USE_JAX: + ifos.set_array_backend(jax.numpy) + +# Set up a PriorDict, which inherits from dict. +# By default we will sample all terms in the signal models. However, this will +# take a long time for the calculation, so for this example we will set almost +# all of the priors to be equall to their injected values. This implies the +# prior is a delta function at the true, injected value. In reality, the +# sampler implementation is smart enough to not sample any parameter that has +# a delta-function prior. 
+# The above list does *not* include mass_1, mass_2, theta_jn and luminosity +# distance, which means those are the parameters that will be included in the +# sampler. If we do nothing, then the default priors get used. +priors = bilby.gw.prior.BBHPriorDict() +for key in [ + "a_1", + "a_2", + "tilt_1", + "tilt_2", + "phi_12", + "phi_jl", + "psi", + "ra", + "dec", + "geocent_time", +]: + priors[key] = injection_parameters[key] + +# Perform a check that the prior does not extend to a parameter space longer than the data +if not USE_JAX: + priors.validate_prior(duration, minimum_frequency) + +# Initialise the likelihood by passing in the interferometer data (ifos) and +# the waveform generator +likelihood = bilby.gw.GravitationalWaveTransient( + interferometers=ifos, + waveform_generator=waveform_generator, + priors=priors, + phase_marginalization=True, +) + +if USE_JAX: + # burn a few likelihood calls to check that we don't get + # repeated compilation + likelihood.parameters.update(priors.sample()) + likelihood.log_likelihood_ratio() + likelihood.log_likelihood() + likelihood.noise_log_likelihood() + + with jax.log_compiles(): + jit_likelihood = bilby.gw.jaxstuff.JittedLikelihood(likelihood) + jit_likelihood.parameters.update(priors.sample()) + jit_likelihood.log_likelihood_ratio() + jit_likelihood.log_likelihood() + jit_likelihood.noise_log_likelihood() + jit_likelihood.parameters.update(priors.sample()) + jit_likelihood.log_likelihood_ratio() + jit_likelihood.log_likelihood() + jit_likelihood.noise_log_likelihood() + sample_likelihood = jit_likelihood +else: + sample_likelihood = likelihood + +# use the log_compiles context so we can make sure there aren't recompilations +# inside the sampling loop +with jax.log_compiles(): + result = bilby.run_sampler( + likelihood=sample_likelihood, + priors=priors, + sampler="dynesty", + npoints=100, + sample="acceptance-walk", + naccept=10, + injection_parameters=injection_parameters, + outdir=outdir, + label=label, + ) + +# 
Make a corner plot. +result.plot_corner() From 208f2272a6a13d9312ddbb1f8a24c21abe0a9963 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 25 Oct 2024 17:24:06 -0500 Subject: [PATCH 002/110] FEAT: support multiband and relative binning likelihoods --- bilby/gw/detector/interferometer.py | 5 +- bilby/gw/jaxstuff.py | 51 ++- bilby/gw/likelihood/multiband.py | 50 ++- bilby/gw/likelihood/relative.py | 10 +- .../injection_examples/jax_fast_tutorial.py | 384 +++++++++++------- 5 files changed, 305 insertions(+), 195 deletions(-) diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 29f52dcd4..2c450cff2 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -6,6 +6,7 @@ three_by_three_matrix_contraction, time_delay_from_geocenter, ) +from bilback.utils import array_module from ...core import utils from ...core.utils import docstring, logger, PropertyAccessor, safe_file_dump @@ -309,7 +310,8 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= frequencies = self.frequency_array mask = self.frequency_mask else: - mask = np.ones(len(frequencies), dtype=bool) + xp = array_module(frequencies) + mask = xp.ones(len(frequencies), dtype=bool) signal = {} for mode in waveform_polarizations.keys(): @@ -330,7 +332,6 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= dt_geocent = parameters['geocent_time'] - self.strain_data.start_time dt = dt_geocent + time_shift - from bilback.utils import array_module xp = array_module(signal_ifo) signal_ifo = signal_ifo * xp.exp(-1j * 2 * np.pi * dt * frequencies) diff --git a/bilby/gw/jaxstuff.py b/bilby/gw/jaxstuff.py index f02df34a3..464a7abb7 100644 --- a/bilby/gw/jaxstuff.py +++ b/bilby/gw/jaxstuff.py @@ -6,7 +6,6 @@ from functools import partial -import numpy as np from bilby.core.likelihood import Likelihood import jax @@ -15,6 +14,7 @@ from jax.scipy.special import i0e from ripple.waveforms import 
IMRPhenomPv2 + def bilby_to_ripple_spins( theta_jn, phi_jl, @@ -34,16 +34,41 @@ def bilby_to_ripple_spins( return iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z -def ripple_bbh(frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase, - a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, **kwargs): +wf_func = jax.jit(IMRPhenomPv2.gen_IMRPhenomPv2) + + +def ripple_bbh_relbin( + frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase, + a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, fiducial, **kwargs, +): + if fiducial == 1: + kwargs["frequencies"] = frequency + else: + kwargs["frequencies"] = kwargs.pop("frequency_bin_edges") + return ripple_bbh( + frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase, + a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, **kwargs + ) + + +def ripple_bbh( + frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase, + a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, **kwargs, +): iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_ripple_spins( theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2 ) + if "frequencies" in kwargs: + frequencies = kwargs["frequencies"] + elif "minimum_frequency" in kwargs: + frequencies = jnp.maximum(frequency, kwargs["minimum_frequency"]) + else: + frequencies = frequency theta = jnp.array([ mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, - luminosity_distance, 0.0, phase, iota + luminosity_distance, jnp.array(0.0), phase, iota ]) - hp, hc = jax.jit(IMRPhenomPv2.gen_IMRPhenomPv2)(frequency, theta, jax.numpy.array(20.0)) + hp, hc = wf_func(frequencies, theta, jax.numpy.array(20.0)) return dict(plus=hp, cross=hc) @@ -90,24 +115,30 @@ class JittedLikelihood(Likelihood): """ def __init__( - self, likelihood, likelihood_func=generic_bilby_likelihood_function, kwargs=None + self, + likelihood, + likelihood_func=generic_bilby_likelihood_function, + kwargs=None, + cast_to_float=True, ): if kwargs is None: kwargs = dict() self.kwargs = kwargs self._likelihood = 
likelihood self.likelihood_func = jax.jit(partial(likelihood_func, likelihood)) + self.cast_to_float = cast_to_float super().__init__(dict()) def __getattr__(self, name): return getattr(self._likelihood, name) def log_likelihood_ratio(self): - return float( - np.nan_to_num(self.likelihood_func(self.parameters, **self.kwargs)) - ) + ln_l = jnp.nan_to_num(self.likelihood_func(self.parameters, **self.kwargs)) + if self.cast_to_float: + ln_l = float(ln_l) + return ln_l @dispatch def ln_i0(value: jax.Array): - return jnp.log(i0e(value)) + jnp.abs(value) \ No newline at end of file + return jnp.log(i0e(value)) + jnp.abs(value) diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index cc7b8e386..e2df98fb2 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -3,6 +3,7 @@ import numbers import numpy as np +from bilback.utils import array_module from .base import GravitationalWaveTransient from ...core.utils import ( @@ -748,32 +749,27 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr """ parameters = _fallback_to_parameters(self, parameters) if self.time_marginalization: - time_ref = self._beam_pattern_reference_time - else: - time_ref = parameters['geocent_time'] - - strain = np.zeros(len(self.banded_frequency_points), dtype=complex) - for mode in waveform_polarizations: - response = interferometer.antenna_response( - parameters['ra'], parameters['dec'], - time_ref, parameters['psi'], mode - ) - strain += waveform_polarizations[mode][self.unique_to_original_frequencies] * response + original_time = parameters["geocent_time"] + parameters["geocent_time"] = self._beam_pattern_reference_time + + modes = { + mode: value[self.unique_to_original_frequencies] + for mode, value in waveform_polarizations.items() + } + strain = interferometer.get_detector_response( + modes, parameters, frequencies=self.banded_frequency_points + ) - dt = interferometer.time_delay_from_geocenter( - 
parameters['ra'], parameters['dec'], time_ref) - dt_geocent = parameters['geocent_time'] - interferometer.strain_data.start_time - ifo_time = dt_geocent + dt - strain *= np.exp(-1j * 2. * np.pi * self.banded_frequency_points * ifo_time) + if self.time_marginalization: + parameters["geocent_time"] = origianl_time - strain *= interferometer.calibration_model.get_calibration_factor( - self.banded_frequency_points, prefix='recalib_{}_'.format(interferometer.name), **parameters) + d_inner_h = (strain @ self.linear_coeffs[interferometer.name]).conjugate() - d_inner_h = np.conj(np.dot(strain, self.linear_coeffs[interferometer.name])) + xp = array_module(strain) if self.linear_interpolation: - optimal_snr_squared = np.vdot( - np.real(strain * np.conjugate(strain)), + optimal_snr_squared = xp.vdot( + (strain * strain.conjugate()).real, self.quadratic_coeffs[interferometer.name] ) else: @@ -783,18 +779,18 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr start_idx, end_idx = self.start_end_idxs[b] Mb = self.Mbs[b] if b == 0: - optimal_snr_squared += (4. / self.interferometers.duration) * np.vdot( - np.real(strain[start_idx:end_idx + 1] * np.conjugate(strain[start_idx:end_idx + 1])), + optimal_snr_squared += (4. / self.interferometers.duration) * xp.vdot( + (strain[start_idx:end_idx + 1] * strain[start_idx:end_idx + 1].conjugate()).real, interferometer.frequency_mask[Ks:Ke + 1] * self.windows[start_idx:end_idx + 1] / interferometer.power_spectral_density_array[Ks:Ke + 1]) else: self.wths[interferometer.name][b][Ks:Ke + 1] = ( self.square_root_windows[start_idx:end_idx + 1] * strain[start_idx:end_idx + 1] ) - self.hbcs[interferometer.name][b][-Mb:] = np.fft.irfft(self.wths[interferometer.name][b]) - thbc = np.fft.rfft(self.hbcs[interferometer.name][b]) - optimal_snr_squared += (4. 
/ self.Tbhats[b]) * np.vdot( - np.real(thbc * np.conjugate(thbc)), self.Ibcs[interferometer.name][b]) + self.hbcs[interferometer.name][b][-Mb:] = xp.fft.irfft(self.wths[interferometer.name][b]) + thbc = xp.fft.rfft(self.hbcs[interferometer.name][b]) + optimal_snr_squared += (4. / self.Tbhats[b]) * xp.vdot( + thbc * np.conjugate(thbc).real, self.Ibcs[interferometer.name][b]) complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) diff --git a/bilby/gw/likelihood/relative.py b/bilby/gw/likelihood/relative.py index f4c72e8ef..199e0e439 100644 --- a/bilby/gw/likelihood/relative.py +++ b/bilby/gw/likelihood/relative.py @@ -2,6 +2,7 @@ import numpy as np from scipy.optimize import differential_evolution +from bilback.utils import array_module from .base import GravitationalWaveTransient from ...core.utils import logger @@ -253,7 +254,7 @@ def set_fiducial_waveforms(self, parameters): for interferometer in self.interferometers: logger.debug(f"Maximum Frequency is {interferometer.maximum_frequency}") wf = interferometer.get_detector_response(self.fiducial_polarizations, parameters) - wf[interferometer.frequency_array > self.maximum_frequency] = 0 + wf *= interferometer.frequency_array <= self.maximum_frequency self.per_detector_fiducial_waveforms[interferometer.name] = wf def find_maximum_likelihood_parameters(self, parameter_bounds, @@ -397,18 +398,19 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr parameters=parameters, ) a0, a1, b0, b1 = self.summary_data[interferometer.name] - d_inner_h = np.sum(a0 * np.conjugate(r0) + a1 * np.conjugate(r1)) - h_inner_h = np.sum(b0 * np.abs(r0) ** 2 + 2 * b1 * np.real(r0 * np.conjugate(r1))) + d_inner_h = (a0 * r0.conjugate() + a1 * r1.conjugate()).sum() + h_inner_h = (b0 * abs(r0) ** 2 + 2 * b1 * (r0 * r1.conjugate()).real).sum() optimal_snr_squared = h_inner_h complex_matched_filter_snr = d_inner_h / (optimal_snr_squared ** 0.5) if return_array and self.time_marginalization: + xp = 
array_module(r0) full_waveform = self._compute_full_waveform( signal_polarizations=waveform_polarizations, interferometer=interferometer, parameters=parameters, ) - d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft( + d_inner_h_array = 4 / self.waveform_generator.duration * xp.fft.fft( full_waveform[0:-1] * interferometer.frequency_domain_strain.conjugate()[0:-1] / interferometer.power_spectral_density_array[0:-1]) diff --git a/examples/gw_examples/injection_examples/jax_fast_tutorial.py b/examples/gw_examples/injection_examples/jax_fast_tutorial.py index b4acfd333..930cc6de5 100644 --- a/examples/gw_examples/injection_examples/jax_fast_tutorial.py +++ b/examples/gw_examples/injection_examples/jax_fast_tutorial.py @@ -9,165 +9,245 @@ We optionally use ripple waveforms and a JIT-compiled likelihood. """ +import os +from itertools import product + +# Set OMP_NUM_THREADS to stop lalsimulation taking over my computer +os.environ["OMP_NUM_THREADS"] = "1" import bilby import bilby.gw.jaxstuff import jax +from numpyro.infer import AIES, ESS # noqa + jax.config.update("jax_enable_x64", True) -USE_JAX = True - -# Set the duration and sampling frequency of the data segment that we're -# going to inject the signal into -duration = 4.0 -sampling_frequency = 2048.0 -minimum_frequency = 20.0 -if USE_JAX: - duration = jax.numpy.array(duration) - sampling_frequency = jax.numpy.array(sampling_frequency) - minimum_frequency = jax.numpy.array(minimum_frequency) - -# Specify the output directory and the name of the simulation. -outdir = "outdir" -label = "fast_tutorial" -bilby.core.utils.setup_logger(outdir=outdir, label=label) - -# Set up a random seed for result reproducibility. This is optional! -bilby.core.utils.random.seed(88170235) - -# We are going to inject a binary black hole waveform. 
We first establish a -# dictionary of parameters that includes all of the different waveform -# parameters, including masses of the two black holes (mass_1, mass_2), -# spins of both black holes (a, tilt, phi), etc. -injection_parameters = dict( - mass_1=36.0, - mass_2=29.0, - a_1=0.4, - a_2=0.3, - tilt_1=0.5, - tilt_2=1.0, - phi_12=1.7, - phi_jl=0.3, - luminosity_distance=2000.0, - theta_jn=0.4, - psi=2.659, - phase=1.3, - geocent_time=1126259642.413, - ra=1.375, - dec=-1.2108, -) - -# Fixed arguments passed into the source model -waveform_arguments = dict( - waveform_approximant="IMRPhenomPv2", - reference_frequency=50.0, - minimum_frequency=minimum_frequency, -) - -if USE_JAX: - fdsm = bilby.gw.jaxstuff.ripple_bbh -else: - fdsm = bilby.gw.source.lal_binary_black_hole - -# Create the waveform_generator using a LAL BinaryBlackHole source function -waveform_generator = bilby.gw.WaveformGenerator( - duration=duration, - sampling_frequency=sampling_frequency, - frequency_domain_source_model=fdsm, - parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters, - waveform_arguments=waveform_arguments, - use_cache=False, -) - -# Set up interferometers. In this case we'll use two interferometers -# (LIGO-Hanford (H1), LIGO-Livingston (L1). These default to their design -# sensitivity -ifos = bilby.gw.detector.InterferometerList(["H1", "L1"]) -ifos.set_strain_data_from_power_spectral_densities( - sampling_frequency=sampling_frequency, - duration=duration, - start_time=injection_parameters["geocent_time"] - 2, -) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) -if USE_JAX: - ifos.set_array_backend(jax.numpy) - -# Set up a PriorDict, which inherits from dict. -# By default we will sample all terms in the signal models. However, this will -# take a long time for the calculation, so for this example we will set almost -# all of the priors to be equall to their injected values. 
This implies the -# prior is a delta function at the true, injected value. In reality, the -# sampler implementation is smart enough to not sample any parameter that has -# a delta-function prior. -# The above list does *not* include mass_1, mass_2, theta_jn and luminosity -# distance, which means those are the parameters that will be included in the -# sampler. If we do nothing, then the default priors get used. -priors = bilby.gw.prior.BBHPriorDict() -for key in [ - "a_1", - "a_2", - "tilt_1", - "tilt_2", - "phi_12", - "phi_jl", - "psi", - "ra", - "dec", - "geocent_time", -]: - priors[key] = injection_parameters[key] - -# Perform a check that the prior does not extend to a parameter space longer than the data -if not USE_JAX: - priors.validate_prior(duration, minimum_frequency) - -# Initialise the likelihood by passing in the interferometer data (ifos) and -# the waveform generator -likelihood = bilby.gw.GravitationalWaveTransient( - interferometers=ifos, - waveform_generator=waveform_generator, - priors=priors, - phase_marginalization=True, -) - -if USE_JAX: - # burn a few likelihood calls to check that we don't get - # repeated compilation - likelihood.parameters.update(priors.sample()) - likelihood.log_likelihood_ratio() - likelihood.log_likelihood() - likelihood.noise_log_likelihood() +bilby.core.utils.setup_logger(log_level="WARNING") + + +def main(use_jax, model): + # Set the duration and sampling frequency of the data segment that we're + # going to inject the signal into + duration = 4.0 + sampling_frequency = 2048.0 + minimum_frequency = 20.0 + if use_jax: + duration = jax.numpy.array(duration) + sampling_frequency = jax.numpy.array(sampling_frequency) + minimum_frequency = jax.numpy.array(minimum_frequency) + + # Specify the output directory and the name of the simulation. + outdir = "outdir" + label = f"{model}_{'jax' if use_jax else 'numpy'}" + + # Set up a random seed for result reproducibility. This is optional! 
+ bilby.core.utils.random.seed(88170235) + + # We are going to inject a binary black hole waveform. We first establish a + # dictionary of parameters that includes all of the different waveform + # parameters, including masses of the two black holes (mass_1, mass_2), + # spins of both black holes (a, tilt, phi), etc. + injection_parameters = dict( + mass_1=36.0, + mass_2=29.0, + a_1=0.4, + a_2=0.3, + tilt_1=0.5, + tilt_2=1.0, + phi_12=1.7, + phi_jl=0.3, + luminosity_distance=2000.0, + theta_jn=0.4, + psi=2.659, + phase=1.3, + geocent_time=1126259642.413, + ra=1.375, + dec=-1.2108, + ) + if model == "relbin": + injection_parameters["fiducial"] = 1 + + # Fixed arguments passed into the source model + waveform_arguments = dict( + waveform_approximant="IMRPhenomPv2", + reference_frequency=50.0, + minimum_frequency=minimum_frequency, + ) - with jax.log_compiles(): - jit_likelihood = bilby.gw.jaxstuff.JittedLikelihood(likelihood) - jit_likelihood.parameters.update(priors.sample()) - jit_likelihood.log_likelihood_ratio() - jit_likelihood.log_likelihood() - jit_likelihood.noise_log_likelihood() - jit_likelihood.parameters.update(priors.sample()) - jit_likelihood.log_likelihood_ratio() - jit_likelihood.log_likelihood() - jit_likelihood.noise_log_likelihood() - sample_likelihood = jit_likelihood -else: - sample_likelihood = likelihood - -# use the log_compiles context so we can make sure there aren't recompilations -# inside the sampling loop -with jax.log_compiles(): - result = bilby.run_sampler( - likelihood=sample_likelihood, + if use_jax: + match model: + case "relbin": + fdsm = bilby.gw.jaxstuff.ripple_bbh_relbin + case _: + fdsm = bilby.gw.jaxstuff.ripple_bbh + else: + match model: + case "relbin": + fdsm = bilby.gw.source.lal_binary_black_hole_relative_binning + case _: + fdsm = bilby.gw.source.lal_binary_black_hole + + # Create the waveform_generator using a LAL BinaryBlackHole source function + waveform_generator = bilby.gw.WaveformGenerator( + duration=duration, + 
sampling_frequency=sampling_frequency, + frequency_domain_source_model=fdsm, + parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters, + waveform_arguments=waveform_arguments, + use_cache=not use_jax, + ) + + # Set up interferometers. In this case we'll use two interferometers + # (LIGO-Hanford (H1), LIGO-Livingston (L1). These default to their design + # sensitivity + ifos = bilby.gw.detector.InterferometerList(["H1", "L1"]) + ifos.set_strain_data_from_power_spectral_densities( + sampling_frequency=sampling_frequency, + duration=duration, + start_time=injection_parameters["geocent_time"] - 2, + ) + ifos.inject_signal( + waveform_generator=waveform_generator, parameters=injection_parameters + ) + if use_jax: + ifos.set_array_backend(jax.numpy) + + if model == "mb": + if use_jax: + pass + else: + waveform_generator.frequency_domain_source_model = ( + bilby.gw.source.binary_black_hole_frequency_sequence + ) + del waveform_generator.waveform_arguments["minimum_frequency"] + + # Set up a PriorDict, which inherits from dict. + # By default we will sample all terms in the signal models. However, this will + # take a long time for the calculation, so for this example we will set almost + # all of the priors to be equall to their injected values. This implies the + # prior is a delta function at the true, injected value. In reality, the + # sampler implementation is smart enough to not sample any parameter that has + # a delta-function prior. + # The above list does *not* include mass_1, mass_2, theta_jn and luminosity + # distance, which means those are the parameters that will be included in the + # sampler. If we do nothing, then the default priors get used. 
+ priors = bilby.gw.prior.BBHPriorDict() + for key in [ + # "a_1", + # "a_2", + # "tilt_1", + # "tilt_2", + # "phi_12", + # "phi_jl", + # "psi", + # "ra", + # "dec", + # "geocent_time", + ]: + priors[key] = injection_parameters[key] + del priors["mass_1"], priors["mass_2"] + priors["L1_time"] = bilby.core.prior.Uniform(1126259642.313, 1126259642.513) + + # Perform a check that the prior does not extend to a parameter space longer than the data + if not use_jax: + priors.validate_prior(duration, minimum_frequency) + + # Initialise the likelihood by passing in the interferometer data (ifos) and + # the waveform generator + match model: + case "relbin": + likelihood_class = ( + bilby.gw.likelihood.RelativeBinningGravitationalWaveTransient + ) + case "mb": + likelihood_class = bilby.gw.likelihood.MBGravitationalWaveTransient + case _: + likelihood_class = bilby.gw.likelihood.GravitationalWaveTransient + likelihood = likelihood_class( + interferometers=ifos, + waveform_generator=waveform_generator, priors=priors, - sampler="dynesty", - npoints=100, - sample="acceptance-walk", - naccept=10, - injection_parameters=injection_parameters, - outdir=outdir, - label=label, + phase_marginalization=True, + reference_frame=ifos, + time_reference="L1", ) -# Make a corner plot. 
-result.plot_corner() + if use_jax: + + def sample(): + parameters = priors.sample() + parameters = {key: jax.numpy.array(val) for key, val in parameters.items()} + return parameters + + # burn a few likelihood calls to check that we don't get + # repeated compilation + likelihood.parameters.update(sample()) + likelihood.log_likelihood_ratio() + likelihood.log_likelihood() + likelihood.noise_log_likelihood() + + with jax.log_compiles(): + jit_likelihood = bilby.gw.jaxstuff.JittedLikelihood( + likelihood, + cast_to_float=False, + jit=True, + ) + jit_likelihood.parameters.update(sample()) + jit_likelihood.log_likelihood_ratio() + jit_likelihood.log_likelihood() + jit_likelihood.noise_log_likelihood() + jit_likelihood.parameters.update(sample()) + jit_likelihood.log_likelihood_ratio() + jit_likelihood.log_likelihood() + jit_likelihood.noise_log_likelihood() + sample_likelihood = jit_likelihood + else: + sample_likelihood = likelihood + + def likelihood_func(parameters): + return sample_likelihood.likelihood_func(parameters, **sample_likelihood.kwargs) + + # import IPython; IPython.embed() + # raise SystemExit() + + # use the log_compiles context so we can make sure there aren't recompilations + # inside the sampling loop + with jax.log_compiles(): + result = bilby.run_sampler( + likelihood=sample_likelihood, + priors=priors, + # sampler="dynesty", + sampler="numpyro", + sampler_name="ESS", + num_warmup=100, + num_samples=100, + num_chains=40, + thinning=2, + # moves={AIES.DEMove(): 0.25, AIES.DEMove(g0=1): 0.5, AIES.StretchMove(): 0.25}, + moves={ + ESS.DifferentialMove(): 0.25, + ESS.KDEMove(): 0.25, + ESS.GaussianMove(): 0.5, + }, + chain_method="vectorized", + npoints=100, + sample="acceptance-walk", + naccept=10, + injection_parameters=injection_parameters, + outdir=outdir, + label=label, + ) + print(result) + print(f"Sampling time: {result.sampling_time:.1f}s\n") + + # Make a corner plot. 
+ result.plot_corner() + raise SystemExit() + return result.sampling_time + + +if __name__ == "__main__": + times = dict() + for arg in product([True, False], ["relbin", "mb", "regular"][-1:]): + times[arg] = main(*arg) + print(times) From c4d9bdf6c74a3194f9755aba40ca4e3c0de1a4a8 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 25 Oct 2024 19:22:34 -0500 Subject: [PATCH 003/110] FEAT: make more conversions backend agnostic --- bilby/gw/conversion.py | 89 ++++++++++++++++++----------------- bilby/gw/detector/networks.py | 24 ++++++++++ bilby/gw/jaxstuff.py | 1 + bilby/gw/likelihood/base.py | 2 + bilby/gw/utils.py | 28 ++--------- 5 files changed, 77 insertions(+), 67 deletions(-) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 87fcfe78e..7c6dbf462 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -9,6 +9,7 @@ import pickle import numpy as np +from bilback.utils import array_module from pandas import DataFrame, Series from scipy.stats import norm @@ -204,9 +205,9 @@ def convert_to_lal_binary_black_hole_parameters(parameters): added_keys: list keys which are added to parameters during function call """ - converted_parameters = parameters.copy() original_keys = list(converted_parameters.keys()) + xp = array_module(parameters[original_keys[0]]) if 'luminosity_distance' not in original_keys: if 'redshift' in converted_parameters.keys(): converted_parameters['luminosity_distance'] = \ @@ -244,7 +245,7 @@ def convert_to_lal_binary_black_hole_parameters(parameters): converted_parameters['a_{}'.format(idx)] = abs( converted_parameters[key]) converted_parameters['cos_tilt_{}'.format(idx)] = \ - np.sign(converted_parameters[key]) + xp.sign(converted_parameters[key]) else: with np.errstate(invalid="raise"): try: @@ -267,13 +268,13 @@ def convert_to_lal_binary_black_hole_parameters(parameters): cos_angle = str('cos_' + angle) if cos_angle in converted_parameters.keys(): with np.errstate(invalid="ignore"): - converted_parameters[angle] = 
np.arccos(converted_parameters[cos_angle]) + converted_parameters[angle] = xp.arccos(converted_parameters[cos_angle]) if "delta_phase" in original_keys: with np.errstate(invalid="ignore"): - converted_parameters["phase"] = np.mod( + converted_parameters["phase"] = xp.mod( converted_parameters["delta_phase"] - - np.sign(np.cos(converted_parameters["theta_jn"])) + - xp.sign(xp.cos(converted_parameters["theta_jn"])) * converted_parameters["psi"], 2 * np.pi) added_keys = [key for key in converted_parameters.keys() @@ -378,19 +379,19 @@ def convert_to_lal_binary_neutron_star_parameters(parameters): g3pca = converted_parameters['eos_spectral_pca_gamma_3'] m1s = converted_parameters['mass_1_source'] m2s = converted_parameters['mass_2_source'] - all_lambda_1 = np.empty(0) - all_lambda_2 = np.empty(0) - all_eos_check = np.empty(0, dtype=bool) + all_lambda_1 = list() + all_lambda_2 = list() + all_eos_check = list() for (g_0pca, g_1pca, g_2pca, g_3pca, m1_s, m2_s) in zip(g0pca, g1pca, g2pca, g3pca, m1s, m2s): g_0, g_1, g_2, g_3 = spectral_pca_to_spectral(g_0pca, g_1pca, g_2pca, g_3pca) lambda_1, lambda_2, eos_check = \ spectral_params_to_lambda_1_lambda_2(g_0, g_1, g_2, g_3, m1_s, m2_s) - all_lambda_1 = np.append(all_lambda_1, lambda_1) - all_lambda_2 = np.append(all_lambda_2, lambda_2) - all_eos_check = np.append(all_eos_check, eos_check) - converted_parameters['lambda_1'] = all_lambda_1 - converted_parameters['lambda_2'] = all_lambda_2 - converted_parameters['eos_check'] = all_eos_check + all_lambda_1.append(lambda_1) + all_lambda_2.append(lambda_2) + all_eos_check.append(eos_check) + converted_parameters['lambda_1'] = np.array(all_lambda_1) + converted_parameters['lambda_2'] = np.array(all_lambda_2) + converted_parameters['eos_check'] = np.array(all_eos_check) for key in float_eos_params.keys(): converted_parameters[key] = float_eos_params[key] elif 'eos_polytrope_gamma_0' and 'eos_polytrope_log10_pressure_1' in converted_parameters.keys(): @@ -630,8 +631,9 @@ def 
spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3) array of gamma_0, gamma_1, gamma_2, gamma_3 in model space ''' - sampled_pca_gammas = np.array([gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3]) - transformation_matrix = np.array( + xp = array_module(gamma_pca_0) + sampled_pca_gammas = xp.array([gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3]) + transformation_matrix = xp.array( [ [0.43801, -0.76705, 0.45143, 0.12646], [-0.53573, 0.17169, 0.67968, 0.47070], @@ -640,10 +642,10 @@ def spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3) ] ) - model_space_mean = np.array([0.89421, 0.33878, -0.07894, 0.00393]) - model_space_standard_deviation = np.array([0.35700, 0.25769, 0.05452, 0.00312]) + model_space_mean = xp.array([0.89421, 0.33878, -0.07894, 0.00393]) + model_space_standard_deviation = xp.array([0.35700, 0.25769, 0.05452, 0.00312]) converted_gamma_parameters = \ - model_space_mean + model_space_standard_deviation * np.dot(transformation_matrix, sampled_pca_gammas) + model_space_mean + model_space_standard_deviation * xp.dot(transformation_matrix, sampled_pca_gammas) return converted_gamma_parameters @@ -958,9 +960,9 @@ def chirp_mass_and_primary_mass_to_mass_ratio(chirp_mass, mass_1): Mass ratio (mass_2/mass_1) of the binary """ a = (chirp_mass / mass_1) ** 5 - t0 = np.cbrt(9 * a + np.sqrt(3) * np.sqrt(27 * a ** 2 - 4 * a ** 3)) - t1 = np.cbrt(2) * 3 ** (2 / 3) - t2 = np.cbrt(2 / 3) * a + t0 = (9 * a + 3**0.5 * (27 * a ** 2 - 4 * a ** 3)**0.5)**(1 / 3) + t1 = (2)**(1 / 3) * 3 ** (2 / 3) + t2 = (2 / 3)**(1 / 3) * a return t2 / t0 + t0 / t1 @@ -1043,8 +1045,8 @@ def component_masses_to_symmetric_mass_ratio(mass_1, mass_2): symmetric_mass_ratio: float Symmetric mass ratio of the binary """ - - return np.minimum((mass_1 * mass_2) / (mass_1 + mass_2) ** 2, 1 / 4) + xp = array_module(mass_1) + return xp.minimum((mass_1 * mass_2) / (mass_1 + mass_2) ** 2, 1 / 4) def component_masses_to_mass_ratio(mass_1, 
mass_2): @@ -1403,17 +1405,17 @@ def binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_s lambda_antisymmetric: float Antisymmetric tidal parameter. """ - lambda_symmetric_m1o5 = np.power(lambda_symmetric, -1. / 5.) + lambda_symmetric_m1o5 = lambda_symmetric ** (-1 / 5) lambda_symmetric_m2o5 = lambda_symmetric_m1o5 * lambda_symmetric_m1o5 lambda_symmetric_m3o5 = lambda_symmetric_m2o5 * lambda_symmetric_m1o5 q = mass_ratio - q2 = np.square(mass_ratio) + q2 = mass_ratio ** 2 # Eqn.2 from CHZ, incorporating the dependence on mass ratio n_polytropic = 0.743 # average polytropic index for the EoSs included in the fit - q_for_Fnofq = np.power(q, 10. / (3. - n_polytropic)) + q_for_Fnofq = q ** (10. / (3. - n_polytropic)) Fnofq = (1. - q_for_Fnofq) / (1. + q_for_Fnofq) # b_ij and c_ij coefficients are given in Table I of CHZ @@ -1483,10 +1485,10 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation(bin lambda_antisymmetric_fitOnly = binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_symmetric, mass_ratio) - lambda_symmetric_sqrt = np.sqrt(lambda_symmetric) + lambda_symmetric_sqrt = lambda_symmetric ** 0.5 q = mass_ratio - q2 = np.square(mass_ratio) + q2 = mass_ratio ** 2 # mu_i and sigma_i coefficients are given in Table II of CHZ @@ -1546,9 +1548,10 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation(bin # Eqn 5 from CHZ, averaging the corrections from the # standard deviations of the residual fits - lambda_antisymmetric_stdCorr = \ - np.sqrt(np.square(lambda_antisymmetric_lambda_symmetric_stdCorr) + - np.square(lambda_antisymmetric_mass_ratio_stdCorr)) + lambda_antisymmetric_stdCorr = ( + lambda_antisymmetric_lambda_symmetric_stdCorr ** 2 + + lambda_antisymmetric_mass_ratio_stdCorr ** 2 + ) ** 0.5 # Draw a correction on the fit from a # Gaussian distribution with width lambda_antisymmetric_stdCorr @@ -2071,28 +2074,29 @@ def generate_spin_parameters(sample): 
output_sample = sample.copy() output_sample = generate_component_spins(output_sample) + xp = array_module(sample["spin_1z"]) output_sample['chi_eff'] = (output_sample['spin_1z'] + output_sample['spin_2z'] * output_sample['mass_ratio']) /\ (1 + output_sample['mass_ratio']) - output_sample['chi_1_in_plane'] = np.sqrt( + output_sample['chi_1_in_plane'] = ( output_sample['spin_1x'] ** 2 + output_sample['spin_1y'] ** 2 - ) - output_sample['chi_2_in_plane'] = np.sqrt( + ) ** 0.5 + output_sample['chi_2_in_plane'] = ( output_sample['spin_2x'] ** 2 + output_sample['spin_2y'] ** 2 - ) + ) ** 0.5 - output_sample['chi_p'] = np.maximum( + output_sample['chi_p'] = xp.maximum( output_sample['chi_1_in_plane'], (4 * output_sample['mass_ratio'] + 3) / (3 * output_sample['mass_ratio'] + 4) * output_sample['mass_ratio'] * output_sample['chi_2_in_plane']) try: - output_sample['cos_tilt_1'] = np.cos(output_sample['tilt_1']) - output_sample['cos_tilt_2'] = np.cos(output_sample['tilt_2']) + output_sample['cos_tilt_1'] = xp.cos(output_sample['tilt_1']) + output_sample['cos_tilt_2'] = xp.cos(output_sample['tilt_2']) except KeyError: pass @@ -2121,12 +2125,13 @@ def generate_component_spins(sample): ['theta_jn', 'phi_jl', 'tilt_1', 'tilt_2', 'phi_12', 'a_1', 'a_2', 'mass_1', 'mass_2', 'reference_frequency', 'phase'] if all(key in output_sample.keys() for key in spin_conversion_parameters): + xp = array_module(output_sample["theta_jn"]) ( output_sample['iota'], output_sample['spin_1x'], output_sample['spin_1y'], output_sample['spin_1z'], output_sample['spin_2x'], output_sample['spin_2y'], output_sample['spin_2z'] - ) = np.vectorize(bilby_to_lalsimulation_spins)( + ) = xp.vectorize(bilby_to_lalsimulation_spins)( output_sample['theta_jn'], output_sample['phi_jl'], output_sample['tilt_1'], output_sample['tilt_2'], output_sample['phi_12'], output_sample['a_1'], output_sample['a_2'], @@ -2136,10 +2141,10 @@ def generate_component_spins(sample): ) output_sample['phi_1'] =\ - np.fmod(2 * np.pi + 
np.arctan2( + xp.fmod(2 * np.pi + xp.arctan2( output_sample['spin_1y'], output_sample['spin_1x']), 2 * np.pi) output_sample['phi_2'] =\ - np.fmod(2 * np.pi + np.arctan2( + xp.fmod(2 * np.pi + xp.arctan2( output_sample['spin_2y'], output_sample['spin_2x']), 2 * np.pi) elif 'chi_1' in output_sample and 'chi_2' in output_sample: diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index b686f2fd5..592e56a93 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -2,6 +2,7 @@ import numpy as np import math +from bilback.geometry import zenith_azimuth_to_theta_phi from ...core import utils from ...core.utils import logger, safe_file_dump @@ -470,3 +471,26 @@ def load_interferometer(filename): "{} could not be loaded. Invalid parameter 'shape'.".format(filename) ) return ifo + + +@zenith_azimuth_to_theta_phi.dispatch +def zenith_azimuth_to_theta_phi(zenith, azimuth, ifos: InterferometerList | list): + """ + Convert from the 'detector frame' to the Earth frame. + + Parameters + ========== + kappa: float + The zenith angle in the detector frame + eta: float + The azimuthal angle in the detector frame + ifos: list + List of Interferometer objects defining the detector frame + + Returns + ======= + theta, phi: float + The zenith and azimuthal angles in the earth frame. + """ + delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex + return zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) diff --git a/bilby/gw/jaxstuff.py b/bilby/gw/jaxstuff.py index 464a7abb7..f1e1c57b0 100644 --- a/bilby/gw/jaxstuff.py +++ b/bilby/gw/jaxstuff.py @@ -86,6 +86,7 @@ def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True): Whether to evaluate the likelihood ratio or the full likelihood. Default is :code:`True`. 
""" + parameters = {k: jnp.array(v) for k, v in parameters.items()} likelihood.parameters.update(parameters) if use_ratio: return likelihood.log_likelihood_ratio() diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index 865944ed0..f7d3df0c8 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -4,6 +4,7 @@ import attr import numpy as np +from bilback.utils import array_module from scipy.special import logsumexp from ...core.likelihood import Likelihood, _fallback_to_parameters @@ -169,6 +170,7 @@ def __init__( if "geocent" not in time_reference: self.time_reference = time_reference self.reference_ifo = get_empty_interferometer(self.time_reference) + self.reference_ifo.set_array_backend(array_module(self.interferometers[0].vertex)) if self.time_marginalization: logger.info("Cannot marginalise over non-geocenter time.") self.time_marginalization = False diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index 8ee65e7fe..865c33042 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -5,9 +5,7 @@ import numpy as np from scipy.interpolate import interp1d from scipy.special import i0e -from bilback.geometry import ( - zenith_azimuth_to_theta_phi as _zenith_azimuth_to_theta_phi, -) +from bilback.geometry import zenith_azimuth_to_theta_phi from bilback.time import greenwich_mean_sidereal_time from bilback.utils import array_module from plum import dispatch @@ -230,28 +228,6 @@ def overlap(signal_a, signal_b, power_spectral_density=None, delta_frequency=Non return sum(integral).real -def zenith_azimuth_to_theta_phi(zenith, azimuth, ifos): - """ - Convert from the 'detector frame' to the Earth frame. - - Parameters - ========== - kappa: float - The zenith angle in the detector frame - eta: float - The azimuthal angle in the detector frame - ifos: list - List of Interferometer objects defining the detector frame - - Returns - ======= - theta, phi: float - The zenith and azimuthal angles in the earth frame. 
- """ - delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex - return _zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) - - def zenith_azimuth_to_ra_dec(zenith, azimuth, geocent_time, ifos): """ Convert from the 'detector frame' to the Earth frame. @@ -272,6 +248,8 @@ def zenith_azimuth_to_ra_dec(zenith, azimuth, geocent_time, ifos): ra, dec: float The zenith and azimuthal angles in the sky frame. """ + # delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex + # theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, ifos) gmst = greenwich_mean_sidereal_time(geocent_time) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) From cd78b34bcbd5a903f484bd50b9d12053681b2f43 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 28 Oct 2024 10:36:55 -0500 Subject: [PATCH 004/110] FEAT: use more normal conversions --- bilby/gw/conversion.py | 47 ++++++++- bilby/gw/likelihood/base.py | 39 ++++---- .../injection_examples/jax_fast_tutorial.py | 99 ++++++++++++++++--- 3 files changed, 150 insertions(+), 35 deletions(-) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 7c6dbf462..093d29ab6 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -179,6 +179,40 @@ def transform_precessing_spins(*args): return lalsim_SimInspiralTransformPrecessingNewInitialConditions(*args) +def convert_orientation_quaternion(parameters): + xp = array_module(parameters["orientation_w"]) + norm = ( + parameters["orientation_w"]**2 + + parameters["orientation_x"]**2 + + parameters["orientation_y"]**2 + + parameters["orientation_z"]**2 + )**0.5 + parameters["theta_jn"] = 2 * xp.arccos( + parameters["orientation_z"] / norm + ) + parameters["psi"] = xp.arctan2( + parameters["orientation_w"], + parameters["orientation_y"] + + parameters["orientation_x"], + ) / 2 + parameters["delta_phase"] = xp.arctan2( + parameters["orientation_y"], + parameters["orientation_x"], + ) / 2 + + +def 
convert_cartesian(parameters, label): + spin_norm = ( + parameters[f"{label}_x"]**2 + + parameters[f"{label}_y"]**2 + + parameters[f"{label}_z"]**2 + )**0.5 + xp = array_module(spin_norm) + zenith = xp.arccos(parameters[f"{label}_z"] / spin_norm) + azimuth = xp.arctan2(parameters[f"{label}_y"], parameters[f"{label}_x"]) + return zenith, azimuth + + def convert_to_lal_binary_black_hole_parameters(parameters): """ Convert parameters we have into parameters we need. @@ -230,6 +264,14 @@ def convert_to_lal_binary_black_hole_parameters(parameters): converted_parameters = generate_component_masses(converted_parameters, require_add=False) for idx in ['1', '2']: + if f"spin_{idx}_x" in original_keys: + converted_parameters["tilt_1"], converted_parameters["phi_jl"] = ( + convert_cartesian(converted_parameters, "spin_1") + ) + converted_parameters["tilt_2"], converted_parameters["phi_12"] = ( + convert_cartesian(converted_parameters, "spin_2") + ) + converted_parameters["phi_12"] -= converted_parameters["phi_jl"] key = 'chi_{}'.format(idx) if key in original_keys: if "chi_{}_in_plane".format(idx) in original_keys: @@ -260,6 +302,9 @@ def convert_to_lal_binary_black_hole_parameters(parameters): ) converted_parameters[f"cos_tilt_{idx}"] = 1.0 + if "orientation_w" in original_keys: + convert_orientation_quaternion(converted_parameters) + for key in ["phi_jl", "phi_12"]: if key not in converted_parameters: converted_parameters[key] = 0.0 @@ -270,7 +315,7 @@ def convert_to_lal_binary_black_hole_parameters(parameters): with np.errstate(invalid="ignore"): converted_parameters[angle] = xp.arccos(converted_parameters[cos_angle]) - if "delta_phase" in original_keys: + if "delta_phase" in converted_parameters: with np.errstate(invalid="ignore"): converted_parameters["phase"] = xp.mod( converted_parameters["delta_phase"] diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index f7d3df0c8..ca6d64aba 100644 --- a/bilby/gw/likelihood/base.py +++ 
b/bilby/gw/likelihood/base.py @@ -411,11 +411,11 @@ def noise_log_likelihood(self): self._noise_log_likelihood_value = self._calculate_noise_log_likelihood() return self._noise_log_likelihood_value - def log_likelihood_ratio(self, parameters=None): if parameters is not None: parameters = copy.deepcopy(parameters) else: parameters = _fallback_to_parameters(self, parameters) + parameters.update(self.get_sky_frame_parameters(parameters)) waveform_polarizations = \ self.waveform_generator.frequency_domain_strain(parameters) if waveform_polarizations is None: @@ -424,8 +424,6 @@ def log_likelihood_ratio(self, parameters=None): if self.time_marginalization and self.jitter_time: parameters['geocent_time'] += parameters['time_jitter'] - parameters.update(self.get_sky_frame_parameters(parameters)) - total_snrs = self._CalculatedSNRs() for interferometer in self.interferometers: @@ -476,14 +474,13 @@ def compute_log_likelihood_from_snrs(self, total_snrs, parameters=None): def compute_per_detector_log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) + parameters.update(self.get_sky_frame_parameters(parameters)) waveform_polarizations = \ self.waveform_generator.frequency_domain_strain(parameters) if self.time_marginalization and self.jitter_time: parameters['geocent_time'] += parameters['time_jitter'] - parameters.update(self.get_sky_frame_parameters(parameters)) - for interferometer in self.interferometers: per_detector_snr = self.calculate_snrs( waveform_polarizations=waveform_polarizations, @@ -1111,6 +1108,7 @@ def get_sky_frame_parameters(self, parameters=None): dict: dictionary containing ra, dec, and geocent_time """ parameters = _fallback_to_parameters(self, parameters) + convert_orientation_quaternion(parameters) time = parameters.get(f'{self.time_reference}_time', None) if time is None and "geocent_time" in parameters: logger.warning( @@ -1118,20 +1116,25 @@ def get_sky_frame_parameters(self, parameters=None): "Falling back 
to geocent time" ) if not self.reference_frame == "sky": - try: + if "sky_x" in parameters: + zenith, azimuth = convert_cartesian(parameters, "sky") + elif "zenith" in parameters: + zenith = parameters["zenith"] + azimuth = parameters["azimuth"] + elif "ra" in parameters and "dec" in parameters: + ra = parameters["ra"] + dec = parameters["dec"] + logger.warning( + "Cannot convert from zenith/azimuth to ra/dec falling " + "back to provided ra/dec" + ) + zenith = None + else: + raise KeyError("No sky location parameters recognised") + if zenith is not None: ra, dec = zenith_azimuth_to_ra_dec( - parameters['zenith'], parameters['azimuth'], - time, self.reference_frame) - except KeyError: - if "ra" in parameters and "dec" in parameters: - ra = parameters["ra"] - dec = parameters["dec"] - logger.warning( - "Cannot convert from zenith/azimuth to ra/dec falling " - "back to provided ra/dec" - ) - else: - raise + zenith, azimuth, time, self.reference_frame + ) else: ra = parameters["ra"] dec = parameters["dec"] diff --git a/examples/gw_examples/injection_examples/jax_fast_tutorial.py b/examples/gw_examples/injection_examples/jax_fast_tutorial.py index 930cc6de5..b1b586ecf 100644 --- a/examples/gw_examples/injection_examples/jax_fast_tutorial.py +++ b/examples/gw_examples/injection_examples/jax_fast_tutorial.py @@ -18,11 +18,14 @@ import bilby import bilby.gw.jaxstuff import jax +import jax.numpy as jnp +from jax import random from numpyro.infer import AIES, ESS # noqa +from numpyro.infer.ensemble_util import get_nondiagonal_indices jax.config.update("jax_enable_x64", True) -bilby.core.utils.setup_logger(log_level="WARNING") +bilby.core.utils.setup_logger() # log_level="WARNING") def main(use_jax, model): @@ -146,7 +149,28 @@ def main(use_jax, model): ]: priors[key] = injection_parameters[key] del priors["mass_1"], priors["mass_2"] - priors["L1_time"] = bilby.core.prior.Uniform(1126259642.313, 1126259642.513) + priors["L1_time"] = bilby.core.prior.Uniform(1126259642.41, 
1126259642.45) + del priors["ra"], priors["dec"] + # priors["zenith"] = bilby.core.prior.Cosine() + # priors["azimuth"] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi) + priors["sky_x"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["sky_y"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["sky_z"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["delta_phase"] = priors.pop("phase") + priors["chirp_mass"].minimum = 20 + priors["chirp_mass"].maximum = 35 + del priors["tilt_1"], priors["tilt_2"], priors["phi_12"], priors["phi_jl"] + priors["spin_1_x"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["spin_1_y"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["spin_1_z"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["spin_2_x"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["spin_2_y"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["spin_2_z"] = bilby.core.prior.Normal(mu=0, sigma=1) + del priors["theta_jn"], priors["psi"], priors["delta_phase"] + priors["orientation_w"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["orientation_x"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["orientation_y"] = bilby.core.prior.Normal(mu=0, sigma=1) + priors["orientation_z"] = bilby.core.prior.Normal(mu=0, sigma=1) # Perform a check that the prior does not extend to a parameter space longer than the data if not use_jax: @@ -167,7 +191,7 @@ def main(use_jax, model): interferometers=ifos, waveform_generator=waveform_generator, priors=priors, - phase_marginalization=True, + # phase_marginalization=True, reference_frame=ifos, time_reference="L1", ) @@ -190,7 +214,6 @@ def sample(): jit_likelihood = bilby.gw.jaxstuff.JittedLikelihood( likelihood, cast_to_float=False, - jit=True, ) jit_likelihood.parameters.update(sample()) jit_likelihood.log_likelihood_ratio() @@ -216,29 +239,36 @@ def likelihood_func(parameters): result = bilby.run_sampler( likelihood=sample_likelihood, priors=priors, - # sampler="dynesty", - sampler="numpyro", + sampler="dynesty", + # 
sampler="numpyro", sampler_name="ESS", - num_warmup=100, - num_samples=100, - num_chains=40, - thinning=2, - # moves={AIES.DEMove(): 0.25, AIES.DEMove(g0=1): 0.5, AIES.StretchMove(): 0.25}, + # sampler_name="NUTS", + num_warmup=500, + num_samples=500, + num_chains=100, + thinning=5, + # moves={ + # AIES.DEMove(): 0.35, + # ModeHopping(): 0.3, + # AIES.StretchMove(): 0.35, + # }, moves={ ESS.DifferentialMove(): 0.25, ESS.KDEMove(): 0.25, ESS.GaussianMove(): 0.5, }, chain_method="vectorized", - npoints=100, - sample="acceptance-walk", + npoints=500, + # sample="acceptance-walk", + sample="act-walk", naccept=10, injection_parameters=injection_parameters, outdir=outdir, label=label, + npool=4, ) - print(result) - print(f"Sampling time: {result.sampling_time:.1f}s\n") + # print(result) + # print(f"Sampling time: {result.sampling_time:.1f}s\n") # Make a corner plot. result.plot_corner() @@ -246,8 +276,45 @@ def likelihood_func(parameters): return result.sampling_time +def ModeHopping(): + """ + A proposal using differential evolution. + + This `Differential evolution proposal + `_ is + implemented following `Nelson et al. (2013) + `_. + + :param sigma: (optional) + The standard deviation of the Gaussian used to stretch the proposal vector. + Defaults to `1.0.e-5`. + :param g0 (optional): + The mean stretch factor for the proposal vector. By default, + it is `2.38 / sqrt(2*ndim)` as recommended by the two references. 
+ """ + + def make_de_move(n_chains): + PAIRS = get_nondiagonal_indices(n_chains // 2) + + def de_move(rng_key, active, inactive): + n_active_chains, _ = inactive.shape + + selected_pairs = random.choice(rng_key, PAIRS, shape=(n_active_chains,)) + + # Compute diff vectors + diffs = jnp.diff(inactive[selected_pairs], axis=1).squeeze(axis=1) + + proposal = active + diffs + + return proposal, jnp.zeros(n_active_chains) + + return de_move + + return make_de_move + + if __name__ == "__main__": times = dict() - for arg in product([True, False], ["relbin", "mb", "regular"][-1:]): + for arg in product([True, False][1:], ["relbin", "mb", "regular"][1:2]): times[arg] = main(*arg) print(times) From 68abf3c4b6146b68a121d60559af0ca6d06a86bc Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 13 Nov 2024 13:25:36 -0800 Subject: [PATCH 005/110] FEAT: move backend switching code to bilby --- bilby/compat/__init__.py | 0 bilby/compat/types.py | 6 + bilby/compat/utils.py | 22 +++ bilby/core/utils/series.py | 2 +- bilby/gw/__init__.py | 1 + bilby/gw/compat/__init__.py | 4 + bilby/gw/compat/jax.py | 18 ++ bilby/gw/detector/geometry.py | 2 +- bilby/gw/detector/interferometer.py | 13 +- bilby/gw/detector/networks.py | 2 +- bilby/gw/geometry.py | 258 +++++++++++++++++++++++++++ bilby/gw/likelihood/base.py | 2 +- bilby/gw/likelihood/multiband.py | 2 +- bilby/gw/likelihood/relative.py | 2 +- bilby/gw/time.py | 259 ++++++++++++++++++++++++++++ bilby/gw/utils.py | 6 +- 16 files changed, 583 insertions(+), 16 deletions(-) create mode 100644 bilby/compat/__init__.py create mode 100644 bilby/compat/types.py create mode 100644 bilby/compat/utils.py create mode 100644 bilby/gw/compat/__init__.py create mode 100644 bilby/gw/compat/jax.py create mode 100644 bilby/gw/geometry.py create mode 100644 bilby/gw/time.py diff --git a/bilby/compat/__init__.py b/bilby/compat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bilby/compat/types.py b/bilby/compat/types.py new file mode 
100644 index 000000000..8a3391c44 --- /dev/null +++ b/bilby/compat/types.py @@ -0,0 +1,6 @@ +from typing import Union +import numpy as np + +Real = Union[float, int] +ArrayLike = Union[np.ndarray, list, tuple] + diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py new file mode 100644 index 000000000..2943c1ae6 --- /dev/null +++ b/bilby/compat/utils.py @@ -0,0 +1,22 @@ +import numpy as np +from array_api_compat import array_namespace + +__all__ = ["array_module", "promote_to_array"] + + +def array_module(arr): + if arr.__class__.__module__ == "builtins": + return np + else: + return array_namespace(arr) + + +def promote_to_array(args, backend, skip=None): + if skip is None: + skip = len(args) + else: + skip = len(args) - skip + if backend.__name__ != "numpy": + args = tuple(backend.array(arg) for arg in args[:skip]) + args[skip:] + return args + diff --git a/bilby/core/utils/series.py b/bilby/core/utils/series.py index c3d71e3e2..8affa61be 100644 --- a/bilby/core/utils/series.py +++ b/bilby/core/utils/series.py @@ -1,5 +1,5 @@ import numpy as np -from bilback.utils import array_module +from ...compat.utils import array_module _TOL = 14 diff --git a/bilby/gw/__init__.py b/bilby/gw/__init__.py index b5115766b..cd09bc6f6 100644 --- a/bilby/gw/__init__.py +++ b/bilby/gw/__init__.py @@ -3,4 +3,5 @@ from .waveform_generator import WaveformGenerator, LALCBCWaveformGenerator from .likelihood import GravitationalWaveTransient from .detector import calibration +from . 
import compat diff --git a/bilby/gw/compat/__init__.py b/bilby/gw/compat/__init__.py new file mode 100644 index 000000000..8e2e63c62 --- /dev/null +++ b/bilby/gw/compat/__init__.py @@ -0,0 +1,4 @@ +try: + from .jax import n_leap_seconds +except ModuleNotFoundError: + pass diff --git a/bilby/gw/compat/jax.py b/bilby/gw/compat/jax.py new file mode 100644 index 000000000..9b0732112 --- /dev/null +++ b/bilby/gw/compat/jax.py @@ -0,0 +1,18 @@ +import jax.numpy as jnp +from jax import Array +from plum import dispatch + +from ..time import LEAP_SECONDS as _LEAP_SECONDS, n_leap_seconds + +__all__ = ["n_leap_seconds"] + +LEAP_SECONDS = jnp.array(_LEAP_SECONDS) + + +@dispatch +def n_leap_seconds(date: Array): + """ + Find the number of leap seconds required for the specified date. + """ + return n_leap_seconds(date, LEAP_SECONDS) + diff --git a/bilby/gw/detector/geometry.py b/bilby/gw/detector/geometry.py index ea2a509ab..627f6a143 100644 --- a/bilby/gw/detector/geometry.py +++ b/bilby/gw/detector/geometry.py @@ -1,5 +1,5 @@ import numpy as np -from bilback.geometry import calculate_arm, detector_tensor +from ..geometry import calculate_arm, detector_tensor from .. import utils as gwutils diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 2c450cff2..437cf6abf 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -1,17 +1,16 @@ import os import numpy as np -from bilback.geometry import ( - get_polarization_tensor, - three_by_three_matrix_contraction, - time_delay_from_geocenter, -) -from bilback.utils import array_module from ...core import utils from ...core.utils import docstring, logger, PropertyAccessor, safe_file_dump -from ...core.utils.env import string_to_boolean +from ...compat.utils import array_module from .. 
import utils as gwutils +from ..geometry import ( + get_polarization_tensor, + three_by_three_matrix_contraction, + time_delay_from_geocenter, +) from .calibration import Recalibrate from .geometry import InterferometerGeometry from .strain_data import InterferometerStrainData diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index 592e56a93..988c1b76e 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -2,10 +2,10 @@ import numpy as np import math -from bilback.geometry import zenith_azimuth_to_theta_phi from ...core import utils from ...core.utils import logger, safe_file_dump +from ..geometry import zenith_azimuth_to_theta_phi from .interferometer import Interferometer from .psd import PowerSpectralDensity diff --git a/bilby/gw/geometry.py b/bilby/gw/geometry.py new file mode 100644 index 000000000..60d55f8be --- /dev/null +++ b/bilby/gw/geometry.py @@ -0,0 +1,258 @@ +import numpy as np +from plum import dispatch +from bilby_rust import geometry as _geometry + +from .time import greenwich_mean_sidereal_time +from ..compat.types import Real, ArrayLike +from ..compat.utils import array_module, promote_to_array + + +__all__ = [ + "antenna_response", + "calculate_arm", + "detector_tensor", + "get_polarization_tensor", + "get_polarization_tensor_multiple_modes", + "rotation_matrix_from_delta", + "three_by_three_matrix_contraction", + "time_delay_geocentric", + "time_delay_from_geocenter", + "zenith_azimuth_to_theta_phi", +] + + +@dispatch +def antenna_response(detector_tensor, ra, dec, time, psi, mode): + """""" + xp = array_module(detector_tensor) + polarization_tensor = get_polarization_tensor(*promote_to_array((ra, dec, time, psi), xp), mode) + return three_by_three_matrix_contraction(detector_tensor, polarization_tensor) + + +@dispatch +def calculate_arm(arm_tilt, arm_azimuth, longitude, latitude): + """""" + xp = array_module(arm_tilt) + e_long = xp.array([-xp.sin(longitude), xp.cos(longitude), longitude * 
0]) + e_lat = xp.array( + [ + -xp.sin(latitude) * xp.cos(longitude), + -xp.sin(latitude) * xp.sin(longitude), + xp.cos(latitude), + ] + ) + e_h = xp.array( + [ + xp.cos(latitude) * xp.cos(longitude), + xp.cos(latitude) * xp.sin(longitude), + xp.sin(latitude), + ] + ) + + return ( + xp.cos(arm_tilt) * xp.cos(arm_azimuth) * e_long + + xp.cos(arm_tilt) * xp.sin(arm_azimuth) * e_lat + + xp.sin(arm_tilt) * e_h + ) + + +@dispatch +def detector_tensor(x, y): + """""" + xp = array_module(x) + return (xp.outer(x, x) - xp.outer(y, y)) / 2 + + +@dispatch +def get_polarization_tensor(ra, dec, time, psi, mode): + """""" + from functools import partial + + xp = array_module(ra) + + gmst = greenwich_mean_sidereal_time(time) % (2 * xp.pi) + phi = ra - gmst + theta = xp.atleast_1d(xp.pi / 2 - dec).squeeze() + u = xp.array( + [ + xp.cos(phi) * xp.cos(theta), + xp.cos(theta) * xp.sin(phi), + -xp.sin(theta) * xp.ones_like(phi), + ] + ) + v = xp.array([ + -xp.sin(phi), xp.cos(phi), xp.zeros_like(phi) + ]) * xp.ones_like(theta) + omega = xp.array([ + xp.sin(xp.pi - theta) * xp.cos(xp.pi + phi), + xp.sin(xp.pi - theta) * xp.sin(xp.pi + phi), + xp.cos(xp.pi - theta) * xp.ones_like(phi), + ]) + m = -u * xp.sin(psi) - v * xp.cos(psi) + n = -u * xp.cos(psi) + v * xp.sin(psi) + if xp.__name__ == "mlx.core": + einsum_shape = "i,j->ij" + else: + einsum_shape = "i...,j...->ij..." 
+ product = partial(xp.einsum, einsum_shape) + + match mode.lower(): + case "plus": + return product(m, m) - product(n, n) + case "cross": + return product(m, n) + product(n, m) + case "breathing": + return product(m, m) + product(n, n) + case "longitudinal": + return product(omega, omega) + case "x": + return product(m, omega) + product(omega, m) + case "y": + return product(n, omega) + product(omega, n) + case _: + raise ValueError(f"{mode} not a polarization mode!") + + +@dispatch +def get_polarization_tensor_multiple_modes(ra, dec, time, psi, modes): + """""" + return [get_polarization_tensor(ra, dec, time, psi, mode) for mode in modes] + + +@dispatch +def rotation_matrix_from_delta(delta_x): + """""" + xp = array_module(delta_x) + delta_x = delta_x / (delta_x**2).sum() ** 0.5 + alpha = xp.arctan2(-delta_x[1] * delta_x[2], delta_x[0]) + beta = xp.arccos(delta_x[2]) + gamma = xp.arctan2(delta_x[1], delta_x[0]) + rotation_1 = xp.array( + [ + [xp.cos(alpha), -xp.sin(alpha), xp.zeros(alpha.shape)], + [xp.sin(alpha), xp.cos(alpha), xp.zeros(alpha.shape)], + [xp.zeros(alpha.shape), xp.zeros(alpha.shape), xp.ones(alpha.shape)], + ] + ) + rotation_2 = xp.array( + [ + [xp.cos(beta), xp.zeros(beta.shape), xp.sin(beta)], + [xp.zeros(beta.shape), xp.ones(beta.shape), xp.zeros(beta.shape)], + [-xp.sin(beta), xp.zeros(beta.shape), xp.cos(beta)], + ] + ) + rotation_3 = xp.array( + [ + [xp.cos(gamma), -xp.sin(gamma), xp.zeros(gamma.shape)], + [xp.sin(gamma), xp.cos(gamma), xp.zeros(gamma.shape)], + [xp.zeros(gamma.shape), xp.zeros(gamma.shape), xp.ones(gamma.shape)], + ] + ) + return rotation_3 @ rotation_2 @ rotation_1 + + + +@dispatch +def three_by_three_matrix_contraction(a, b): + """""" + xp = array_module(a) + return xp.einsum("ij,ij->", a, b) + + +@dispatch +def time_delay_geocentric(detector1, detector2, ra, dec, time): + """""" + xp = array_module(detector1) + gmst = greenwich_mean_sidereal_time(time) % (2 * xp.pi) + speed_of_light = 299792458.0 + phi = ra - gmst + 
theta = xp.pi / 2 - dec + omega = xp.array( + [xp.sin(theta) * xp.cos(phi), xp.sin(theta) * xp.sin(phi), xp.cos(theta)] + ) + delta_d = detector2 - detector1 + return omega @ delta_d / speed_of_light + + + +@dispatch +def time_delay_from_geocenter(detector1, ra, dec, time): + """""" + xp = array_module(detector1) + return time_delay_geocentric(detector1, xp.zeros(3), ra, dec, time) + + +@dispatch +def zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x): + """""" + xp = array_module(delta_x) + omega_prime = xp.array( + [ + xp.sin(zenith) * xp.cos(azimuth), + xp.sin(zenith) * xp.sin(azimuth), + xp.cos(zenith), + ] + ) + rotation_matrix = rotation_matrix_from_delta(delta_x) + omega = rotation_matrix @ omega_prime + theta = xp.arccos(omega[2]) + phi = xp.arctan2(omega[1], omega[0]) % (2 * xp.pi) + return theta, phi + + + +# @dispatch(precedence=1) +# def antenna_response(detector_tensor: np.ndarray, ra: FloatOrInt, dec: FloatOrInt, time: FloatOrInt, psi: FloatOrInt, mode: str): +# return _geometry.antenna_response(detector_tensor, ra, dec, time, psi, mode) + + +@dispatch(precedence=1) +def calculate_arm(arm_tilt: Real, arm_azimuth: Real, longitude: Real, latitude: Real): + return _geometry.calculate_arm(arm_tilt, arm_azimuth, longitude, latitude) + + +@dispatch(precedence=1) +def detector_tensor(x: ArrayLike, y: ArrayLike): + return _geometry.detector_tensor(x, y) + + +@dispatch(precedence=1) +def get_polarization_tensor(ra: Real, dec: Real, time: Real, psi: Real, mode: str): + return _geometry.get_polarization_tensor(ra, dec, time, psi, mode) + + +# @dispatch(precedence=1) +# def get_polarization_tensor_multiple_modes(ra: FloatOrInt, dec: FloatOrInt, time: FloatOrInt, psi: FloatOrInt, modes: list[str]): +# return [geometry.get_polarization_tensor(ra, dec, time, psi, mode) for mode in modes] + + +@dispatch(precedence=1) +def rotation_matrix_from_delta(delta: ArrayLike): + return _geometry.rotation_matrix_from_delta_x(delta) + + +# @dispatch(precedence=1) +# def 
three_by_three_matrix_contraction(a: ArrayLike, b: ArrayLike): +# return _geometry.three_by_three_matrix_contraction(a, b) + + +@dispatch(precedence=1) +def time_delay_geocentric(detector1: ArrayLike, detector2: ArrayLike, ra, dec, time): + return _geometry.time_delay_geocentric(detector1, detector2, ra, dec, time) + + +@dispatch(precedence=1) +def time_delay_from_geocenter(detector1: ArrayLike, ra: Real, dec: Real, time: Real): + return _geometry.time_delay_from_geocenter(detector1, ra, dec, time) + + +@dispatch(precedence=1) +def time_delay_from_geocenter(detector1: ArrayLike, ra: Real, dec: Real, time: ArrayLike): + return _geometry.time_delay_from_geocenter_vectorized(detector1, ra, dec, time) + + +@dispatch(precedence=1) +def zenith_azimuth_to_theta_phi(zenith: Real, azimuth: Real, delta_x: np.ndarray): + theta, phi = _geometry.zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + return theta, phi % (2 * np.pi) + diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index ca6d64aba..d98d4f481 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -4,9 +4,9 @@ import attr import numpy as np -from bilback.utils import array_module from scipy.special import logsumexp +from ...compat.utils import array_module from ...core.likelihood import Likelihood, _fallback_to_parameters from ...core.utils import logger, BoundedRectBivariateSpline, create_time_series from ...core.prior import Interped, Prior, Uniform, DeltaFunction diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index e2df98fb2..46d29c915 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -3,9 +3,9 @@ import numbers import numpy as np -from bilback.utils import array_module from .base import GravitationalWaveTransient +from ...compat.utils import array_module from ...core.utils import ( logger, speed_of_light, solar_mass, radius_of_earth, gravitational_constant, round_up_to_power_of_two, diff --git 
a/bilby/gw/likelihood/relative.py b/bilby/gw/likelihood/relative.py index 199e0e439..48a4cd585 100644 --- a/bilby/gw/likelihood/relative.py +++ b/bilby/gw/likelihood/relative.py @@ -2,9 +2,9 @@ import numpy as np from scipy.optimize import differential_evolution -from bilback.utils import array_module from .base import GravitationalWaveTransient +from ...compat.utils import array_module from ...core.utils import logger from ...core.prior.base import Constraint from ...core.prior import DeltaFunction diff --git a/bilby/gw/time.py b/bilby/gw/time.py new file mode 100644 index 000000000..cb56b3bab --- /dev/null +++ b/bilby/gw/time.py @@ -0,0 +1,259 @@ +from typing import Union + +import numpy as np +from plum import dispatch +from bilby_rust import time as _time + +from ..compat.types import Real, ArrayLike +from ..compat.utils import array_module + + +__all__ = [ + "datetime", + "gps_time_to_utc", + "greenwich_mean_sidereal_time", + "greenwich_sidereal_time", + "n_leap_seconds", + "utc_to_julian_day", + "LEAP_SECONDS", +] + + +class datetime: + """ + A barebones datetime class for use in the GPS to GMST conversion. + """ + + def __init__( + self, + year: int = 0, + month: int = 0, + day: int = 0, + hour: int = 0, + minute: int = 0, + second: float = 0, + ): + self.year = year + self.month = month + self.day = day + self.hour = hour + self.minute = minute + self.second = second + + def __repr__(self): + return f"{self.year}-{self.month}-{self.day} {self.hour}:{self.minute}:{self.second}" + + def __add__(self, other): + """ + Add two datetimes together. + Note that this does not handle overflow and can lead to unphysical + values for the various attributes. 
+ """ + return datetime( + self.year + other.year, + self.month + other.month, + self.day + other.day, + self.hour + other.hour, + self.minute + other.minute, + self.second + other.second, + ) + + @property + def julian_day(self): + return ( + 367 * self.year + - 7 * (self.year + (self.month + 9) // 12) // 4 + + 275 * self.month // 9 + + self.day + + self.second / SECONDS_PER_DAY + + JULIAN_GPS_EPOCH + ) + + +GPS_EPOCH = datetime(1980, 1, 6, 0, 0, 0) +JULIAN_GPS_EPOCH = 1721013.5 +EPOCH_J2000_0_JD = 2451545.0 +DAYS_PER_CENTURY = 36525.0 +SECONDS_PER_DAY = 86400.0 +LEAP_SECONDS = [ + 46828800, + 78364801, + 109900802, + 173059203, + 252028804, + 315187205, + 346723206, + 393984007, + 425520008, + 457056009, + 504489610, + 551750411, + 599184012, + 820108813, + 914803214, + 1025136015, + 1119744016, + 1167264017, +] + + +@dispatch +def gps_time_to_utc(gps_time): + """ + Convert GPS time to UTC. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + + Returns + ------- + datetime + UTC time. + """ + return GPS_EPOCH + datetime(second=gps_time - n_leap_seconds(gps_time)) + + +@dispatch +def greenwich_mean_sidereal_time(gps_time): + """ + Calculate the Greenwich Mean Sidereal Time. + + This is a thin wrapper around :py:func:`greenwich_sidereal_time` with the + equation of the equinoxes set to zero. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + + Returns + ------- + float + Greenwich Mean Sidereal Time in radians. + """ + return greenwich_sidereal_time(gps_time, gps_time * 0) + + +@dispatch +def greenwich_sidereal_time(gps_time, equation_of_equinoxes): + """ + Calculate the Greenwich Sidereal Time. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + equation_of_equinoxes : float + Equation of the equinoxes in seconds. 
+ + Returns + ------- + float + """ + julian_day = utc_to_julian_day(gps_time_to_utc(gps_time // 1)) + t_hi = (julian_day - EPOCH_J2000_0_JD) / DAYS_PER_CENTURY + t_lo = (gps_time % 1) / (DAYS_PER_CENTURY * SECONDS_PER_DAY) + + t = t_hi + t_lo + + sidereal_time = ( + equation_of_equinoxes + (-6.2e-6 * t + 0.093104) * t**2 + 67310.54841 + ) + sidereal_time += 8640184.812866 * t_lo + sidereal_time += 3155760000.0 * t_lo + sidereal_time += 8640184.812866 * t_hi + sidereal_time += 3155760000.0 * t_hi + + return sidereal_time * 2 * np.pi / SECONDS_PER_DAY + + +@dispatch +def n_leap_seconds(gps_time, leap_seconds): + """ + Calculate the number of leap seconds that have occurred up to a given GPS time. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + leap_seconds : array_like + GPS time of leap seconds. + + Returns + ------- + float + Number of leap seconds + """ + xp = array_module(gps_time) + return xp.sum(gps_time > leap_seconds[:, None], axis=0).squeeze() + + +@dispatch +def n_leap_seconds(gps_time: Union[np.ndarray, float, int]): + """ + Calculate the number of leap seconds that have occurred up to a given GPS time. + + Parameters + ---------- + gps_time : float + GPS time in seconds. + + Returns + ------- + float + Number of leap seconds + """ + xp = array_module(gps_time) + return n_leap_seconds(gps_time, xp.array(LEAP_SECONDS)) + + +@dispatch +def utc_to_julian_day(utc_time): + """ + Convert UTC time to Julian day. + + Parameters + ---------- + utc_time : datetime + UTC time. + + Returns + ------- + float + Julian day. 
+ + """ + return utc_time.julian_day + + +@dispatch(precedence=1) +def gps_time_to_utc(gps_time: Real): + return _time.gps_time_to_utc(gps_time) + + +@dispatch(precedence=1) +def greenwich_mean_sidereal_time(gps_time: Real): + return _time.greenwich_mean_sidereal_time(gps_time) + + +@dispatch(precedence=1) +def greenwich_mean_sidereal_time(gps_time: ArrayLike): + return _time.greenwich_mean_sidereal_time_vectorized(gps_time) + + +@dispatch(precedence=1) +def greenwich_sidereal_time(gps_time: Real, equation_of_equinoxes: Real): + return _time.greenwich_sidereal_time(gps_time, equation_of_equinoxes) + + +@dispatch(precedence=1) +def n_leap_seconds(gps_time: Real): + return _time.n_leap_seconds(gps_time) + + +@dispatch(precedence=1) +def utc_to_julian_day(utc_time: Real): + return _time.utc_to_julian_day(utc_time) + diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index 865c33042..719664635 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -5,11 +5,11 @@ import numpy as np from scipy.interpolate import interp1d from scipy.special import i0e -from bilback.geometry import zenith_azimuth_to_theta_phi -from bilback.time import greenwich_mean_sidereal_time -from bilback.utils import array_module from plum import dispatch +from .geometry import zenith_azimuth_to_theta_phi +from .time import greenwich_mean_sidereal_time +from ...compat.utils import array_module from ..core.utils import (logger, run_commandline, check_directory_exists_and_if_not_mkdir, SamplesSummary, theta_phi_to_ra_dec) From 47041a15f22f9230213b050b7dfaac2f87cafc88 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 14 Nov 2024 08:29:15 -0800 Subject: [PATCH 006/110] FEAT: make core prior backend agnostic --- bilby/compat/utils.py | 14 +- bilby/core/prior/analytical.py | 338 +++++++++++++++++++-------------- bilby/core/prior/dict.py | 16 +- 3 files changed, 214 insertions(+), 154 deletions(-) diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py index 2943c1ae6..bd97f828e 100644 --- 
a/bilby/compat/utils.py +++ b/bilby/compat/utils.py @@ -1,5 +1,5 @@ import numpy as np -from array_api_compat import array_namespace +from scipy._lib._array_api import array_namespace __all__ = ["array_module", "promote_to_array"] @@ -20,3 +20,15 @@ def promote_to_array(args, backend, skip=None): args = tuple(backend.array(arg) for arg in args[:skip]) + args[skip:] return args + +def xp_wrap(func): + + def wrapped(self, *args, **kwargs): + if "xp" not in kwargs: + try: + kwargs["xp"] = array_module(*args) + except TypeError: + pass + return func(self, *args, **kwargs) + + return wrapped diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index bc47cf680..974d5cbe4 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -1,21 +1,24 @@ +import os + import numpy as np +os.environ["SCIPY_ARRAY_API"] = "1" # noqa # flag for scipy backend switching from scipy.special import ( - xlogy, - erf, - erfinv, - log1p, - stdtrit, - gammaln, - stdtr, - betaln, betainc, betaincinv, + betaln, + erf, + # erfinv, # erfinv is not currently backend agnostic gammaincinv, gammainc, + gammaln, + stdtr, + stdtrit, + xlogy, ) from .base import Prior from ..utils import logger +from ...compat.utils import xp_wrap class DeltaFunction(Prior): @@ -67,10 +70,10 @@ def prob(self, val): """ at_peak = (val == self.peak) - return np.nan_to_num(np.multiply(at_peak, np.inf)) + return at_peak * 1.0 def cdf(self, val): - return np.ones_like(val) * (val > self.peak) + return 1.0 * (val > self.peak) class PowerLaw(Prior): @@ -101,7 +104,8 @@ def __init__(self, alpha, minimum, maximum, name=None, latex_label=None, boundary=boundary) self.alpha = alpha - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the power-law prior. 
@@ -117,12 +121,13 @@ def rescale(self, val): Union[float, array_like]: Rescaled probability """ if self.alpha == -1: - return self.minimum * np.exp(val * np.log(self.maximum / self.minimum)) + return self.minimum * xp.exp(val * xp.log(self.maximum / self.minimum)) else: return (self.minimum ** (1 + self.alpha) + val * (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) ** (1. / (1 + self.alpha)) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val Parameters @@ -134,13 +139,14 @@ def prob(self, val): float: Prior probability of val """ if self.alpha == -1: - return np.nan_to_num(1 / val / np.log(self.maximum / self.minimum)) * self.is_in_prior_range(val) + return xp.nan_to_num(1 / val / xp.log(self.maximum / self.minimum)) * self.is_in_prior_range(val) else: - return np.nan_to_num(val ** self.alpha * (1 + self.alpha) / + return xp.nan_to_num(val ** self.alpha * (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) * self.is_in_prior_range(val) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Return the logarithmic prior probability of val Parameters @@ -153,28 +159,27 @@ def ln_prob(self, val): """ if self.alpha == -1: - normalising = 1. / np.log(self.maximum / self.minimum) + normalising = 1. / xp.log(self.maximum / self.minimum) else: normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha)) with np.errstate(divide='ignore', invalid='ignore'): - ln_in_range = np.log(1. * self.is_in_prior_range(val)) - ln_p = self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising) + ln_in_range = xp.log(1. 
* self.is_in_prior_range(val)) + ln_p = self.alpha * xp.nan_to_num(xp.log(val)) + xp.log(normalising) return ln_p + ln_in_range - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=np): if self.alpha == -1: - _cdf = (np.log(val / self.minimum) / - np.log(self.maximum / self.minimum)) + _cdf = xp.log(val / self.minimum) / xp.log(self.maximum / self.minimum) else: _cdf = ( - (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1) / (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) ) - _cdf = np.minimum(_cdf, 1) - _cdf = np.maximum(_cdf, 0) + _cdf = xp.clip(_cdf, 0, 1) return _cdf @@ -233,7 +238,8 @@ def prob(self, val): """ return ((val >= self.minimum) & (val <= self.maximum)) / (self.maximum - self.minimum) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Return the log prior probability of val Parameters @@ -244,13 +250,13 @@ def ln_prob(self, val): ======= float: log probability of val """ - return xlogy(1, (val >= self.minimum) & (val <= self.maximum)) - xlogy(1, self.maximum - self.minimum) + with np.errstate(divide="ignore"): + return xp.log(self.prob(val)) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=np): _cdf = (val - self.minimum) / (self.maximum - self.minimum) - _cdf = np.minimum(_cdf, 1) - _cdf = np.maximum(_cdf, 0) - return _cdf + return xp.clip(_cdf, 0, 1) class LogUniform(PowerLaw): @@ -310,7 +316,8 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the power-law prior. 
@@ -327,19 +334,20 @@ def rescale(self, val): """ if isinstance(val, (float, int)): if val < 0.5: - return -self.maximum * np.exp(-2 * val * np.log(self.maximum / self.minimum)) + return -self.maximum * xp.exp(-2 * val * xp.log(self.maximum / self.minimum)) else: - return self.minimum * np.exp(np.log(self.maximum / self.minimum) * (2 * val - 1)) + return self.minimum * xp.exp(xp.log(self.maximum / self.minimum) * (2 * val - 1)) else: vals_less_than_5 = val < 0.5 - rescaled = np.empty_like(val) - rescaled[vals_less_than_5] = -self.maximum * np.exp(-2 * val[vals_less_than_5] * - np.log(self.maximum / self.minimum)) - rescaled[~vals_less_than_5] = self.minimum * np.exp(np.log(self.maximum / self.minimum) * + rescaled = xp.empty_like(val) + rescaled[vals_less_than_5] = -self.maximum * xp.exp(-2 * val[vals_less_than_5] * + xp.log(self.maximum / self.minimum)) + rescaled[~vals_less_than_5] = self.minimum * xp.exp(xp.log(self.maximum / self.minimum) * (2 * val[~vals_less_than_5] - 1)) return rescaled - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val Parameters @@ -350,11 +358,12 @@ def prob(self, val): ======= float: Prior probability of val """ - val = np.abs(val) - return (np.nan_to_num(0.5 / val / np.log(self.maximum / self.minimum)) * + val = xp.abs(val) + return (xp.nan_to_num(0.5 / val / xp.log(self.maximum / self.minimum)) * self.is_in_prior_range(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Return the logarithmic prior probability of val Parameters @@ -366,19 +375,12 @@ def ln_prob(self, val): float: """ - return np.nan_to_num(- np.log(2 * np.abs(val)) - np.log(np.log(self.maximum / self.minimum))) + return np.nan_to_num(- xp.log(2 * xp.abs(val)) - xp.log(xp.log(self.maximum / self.minimum))) - def cdf(self, val): - norm = 0.5 / np.log(self.maximum / self.minimum) - _cdf = ( - -norm * np.log(abs(val) / self.maximum) - * (val <= -self.minimum) * (val >= -self.maximum) - + 
(0.5 + norm * np.log(abs(val) / self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 0.5 * (val > -self.minimum) * (val < self.minimum) - + 1 * (val > self.maximum) - ) - return _cdf + @xp_wrap + def cdf(self, val, *, xp=np): + asymmetric = LogUniform.cdf(self, abs(val), xp) + return 0.5 * (1 + xp.sign(val) * asymmetric) class Cosine(Prior): @@ -405,16 +407,18 @@ def __init__(self, minimum=-np.pi / 2, maximum=np.pi / 2, name=None, super(Cosine, self).__init__(minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to a uniform in cosine prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - norm = 1 / (np.sin(self.maximum) - np.sin(self.minimum)) - return np.arcsin(val / norm + np.sin(self.minimum)) + norm = 1 / (xp.sin(self.maximum) - xp.sin(self.minimum)) + return xp.arcsin(val / norm + xp.sin(self.minimum)) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. Defined over [-pi/2, pi/2]. 
Parameters @@ -425,15 +429,17 @@ def prob(self, val): ======= float: Prior probability of val """ - return np.cos(val) / 2 * self.is_in_prior_range(val) + return xp.cos(val) / 2 * self.is_in_prior_range(val) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=np): _cdf = ( - (np.sin(val) - np.sin(self.minimum)) - / (np.sin(self.maximum) - np.sin(self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) + (xp.sin(val) - xp.sin(self.minimum)) / + (xp.sin(self.maximum) - xp.sin(self.minimum)) ) + _cdf *= val >= self.minimum + _cdf *= val <= self.maximum + _cdf += val > self.maximum return _cdf @@ -461,16 +467,18 @@ def __init__(self, minimum=0, maximum=np.pi, name=None, super(Sine, self).__init__(minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to a uniform in sine prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - norm = 1 / (np.cos(self.minimum) - np.cos(self.maximum)) - return np.arccos(np.cos(self.minimum) - val / norm) + norm = 1 / (xp.cos(self.minimum) - xp.cos(self.maximum)) + return xp.arccos(xp.cos(self.minimum) - val / norm) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. Defined over [0, pi]. 
Parameters @@ -481,15 +489,17 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.sin(val) / 2 * self.is_in_prior_range(val) + return xp.sin(val) / 2 * self.is_in_prior_range(val) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=np): _cdf = ( - (np.cos(val) - np.cos(self.minimum)) - / (np.cos(self.maximum) - np.cos(self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) + (xp.cos(val) - xp.cos(self.minimum)) + / (xp.cos(self.maximum) - xp.cos(self.minimum)) ) + _cdf *= val >= self.minimum + _cdf *= val <= self.maximum + _cdf += val > self.maximum return _cdf @@ -517,7 +527,8 @@ def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=N self.mu = mu self.sigma = sigma - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the appropriate Gaussian prior. @@ -527,9 +538,14 @@ def rescale(self, val): This maps to the inverse CDF. This has been analytically solved for this case. """ + if "jax" in xp.__name__: + from jax.scipy.special import erfinv + else: + from scipy.special import erfinv return self.mu + erfinv(2 * val - 1) * 2 ** 0.5 * self.sigma - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. Parameters @@ -540,9 +556,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 / self.sigma + return xp.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 / self.sigma - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Return the Log prior probability of val. 
Parameters @@ -553,8 +570,7 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - - return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + np.log(2 * np.pi * self.sigma ** 2)) + return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + xp.log(2 * np.pi * self.sigma ** 2)) def cdf(self, val): return (1 - erf((self.mu - val) / 2 ** 0.5 / self.sigma)) / 2 @@ -607,16 +623,22 @@ def normalisation(self): return (erf((self.maximum - self.mu) / 2 ** 0.5 / self.sigma) - erf( (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the appropriate truncated Gaussian prior. This maps to the inverse CDF. This has been analytically solved for this case. """ + if "jax" in xp.__name__: + from jax.scipy.special import erfinv + else: + from scipy.special import erfinv return erfinv(2 * val * self.normalisation + erf( (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) * 2 ** 0.5 * self.sigma + self.mu - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. 
Parameters @@ -627,17 +649,15 @@ def prob(self, val): ======= float: Prior probability of val """ - return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 \ + return xp.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 \ / self.sigma / self.normalisation * self.is_in_prior_range(val) def cdf(self, val): - _cdf = ( - ( - erf((val - self.mu) / 2 ** 0.5 / self.sigma) - - erf((self.minimum - self.mu) / 2 ** 0.5 / self.sigma) - ) / 2 / self.normalisation * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) - ) + _cdf = (erf((val - self.mu) / 2 ** 0.5 / self.sigma) - erf( + (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 / self.normalisation + _cdf *= val >= self.minimum + _cdf *= val <= self.maximum + _cdf += val > self.maximum return _cdf @@ -701,15 +721,21 @@ def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=N self.mu = mu self.sigma = sigma - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the appropriate LogNormal prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - return np.exp(self.mu + np.sqrt(2 * self.sigma ** 2) * erfinv(2 * val - 1)) + if "jax" in xp.__name__: + from jax.scipy.special import erfinv + else: + from scipy.special import erfinv + return xp.exp(self.mu + (2 * self.sigma ** 2)**0.5 * erfinv(2 * val - 1)) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Returns the prior probability of val. Parameters @@ -724,16 +750,17 @@ def prob(self, val): if val <= self.minimum: _prob = 0. 
else: - _prob = np.exp(-(np.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2)\ - / np.sqrt(2 * np.pi) / val / self.sigma + _prob = xp.exp(-(xp.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2)\ + / xp.sqrt(2 * np.pi) / val / self.sigma else: - _prob = np.zeros(val.size) + _prob = xp.zeros(val.size) idx = (val > self.minimum) - _prob[idx] = np.exp(-(np.log(val[idx]) - self.mu) ** 2 / self.sigma ** 2 / 2)\ - / np.sqrt(2 * np.pi) / val[idx] / self.sigma + _prob[idx] = xp.exp(-(xp.log(val[idx]) - self.mu) ** 2 / self.sigma ** 2 / 2)\ + / xp.sqrt(2 * np.pi) / val[idx] / self.sigma return _prob - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Returns the log prior probability of val. Parameters @@ -746,27 +773,28 @@ def ln_prob(self, val): """ if isinstance(val, (float, int)): if val <= self.minimum: - _ln_prob = -np.inf + _ln_prob = -xp.inf else: - _ln_prob = -(np.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2\ - - np.log(np.sqrt(2 * np.pi) * val * self.sigma) + _ln_prob = -(xp.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2\ + - xp.log(xp.sqrt(2 * np.pi) * val * self.sigma) else: - _ln_prob = -np.inf * np.ones(val.size) + _ln_prob = -xp.inf * xp.ones(val.size) idx = (val > self.minimum) - _ln_prob[idx] = -(np.log(val[idx]) - self.mu) ** 2\ - / self.sigma ** 2 / 2 - np.log(np.sqrt(2 * np.pi) * val[idx] * self.sigma) + _ln_prob[idx] = -(xp.log(val[idx]) - self.mu) ** 2\ + / self.sigma ** 2 / 2 - xp.log(xp.sqrt(2 * np.pi) * val[idx] * self.sigma) return _ln_prob - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=np): if isinstance(val, (float, int)): if val <= self.minimum: _cdf = 0. 
else: _cdf = 0.5 + erf((np.log(val) - self.mu) / self.sigma / np.sqrt(2)) / 2 else: - _cdf = np.zeros(val.size) + _cdf = xp.zeros(val.size) _cdf[val > self.minimum] = 0.5 + erf(( - np.log(val[val > self.minimum]) - self.mu) / self.sigma / np.sqrt(2)) / 2 + xp.log(val[val > self.minimum]) - self.mu) / self.sigma / np.sqrt(2)) / 2 return _cdf @@ -795,15 +823,17 @@ def __init__(self, mu, name=None, latex_label=None, unit=None, boundary=None): unit=unit, boundary=boundary) self.mu = mu - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the appropriate Exponential prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - return -self.mu * log1p(-val) + return -self.mu * xp.log1p(-val) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. Parameters @@ -818,13 +848,14 @@ def prob(self, val): if val < self.minimum: _prob = 0. else: - _prob = np.exp(-val / self.mu) / self.mu + _prob = xp.exp(-val / self.mu) / self.mu else: - _prob = np.zeros(val.size) - _prob[val >= self.minimum] = np.exp(-val[val >= self.minimum] / self.mu) / self.mu + _prob = xp.zeros(val.size) + _prob[val >= self.minimum] = xp.exp(-val[val >= self.minimum] / self.mu) / self.mu return _prob - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Returns the log prior probability of val. 
Parameters @@ -837,23 +868,24 @@ def ln_prob(self, val): """ if isinstance(val, (float, int)): if val < self.minimum: - _ln_prob = -np.inf + _ln_prob = -xp.inf else: - _ln_prob = -val / self.mu - np.log(self.mu) + _ln_prob = -val / self.mu - xp.log(self.mu) else: - _ln_prob = -np.inf * np.ones(val.size) - _ln_prob[val >= self.minimum] = -val[val >= self.minimum] / self.mu - np.log(self.mu) + _ln_prob = -xp.inf * xp.ones(val.size) + _ln_prob[val >= self.minimum] = -val[val >= self.minimum] / self.mu - xp.log(self.mu) return _ln_prob - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=np): if isinstance(val, (float, int)): if val < self.minimum: _cdf = 0. else: - _cdf = 1. - np.exp(-val / self.mu) + _cdf = 1. - xp.exp(-val / self.mu) else: - _cdf = np.zeros(val.size) - _cdf[val >= self.minimum] = 1. - np.exp(-val[val >= self.minimum] / self.mu) + _cdf = xp.zeros(val.size) + _cdf[val >= self.minimum] = 1. - xp.exp(-val[val >= self.minimum] / self.mu) return _cdf @@ -891,7 +923,8 @@ def __init__(self, df, mu=0., scale=1., name=None, latex_label=None, self.mu = mu self.scale = scale - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the appropriate Student's t-prior. @@ -906,11 +939,12 @@ def rescale(self, val): rescaled = stdtrit(self.df, val) * self.scale + self.mu else: rescaled = stdtrit(self.df, val) * self.scale + self.mu - rescaled[val == 0] = -np.inf - rescaled[val == 1] = np.inf + rescaled[val == 0] = -xp.inf + rescaled[val == 1] = xp.inf return rescaled - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. Parameters @@ -921,9 +955,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Returns the log prior probability of val. 
Parameters @@ -935,8 +970,8 @@ def ln_prob(self, val): Union[float, array_like]: Prior probability of val """ return gammaln(0.5 * (self.df + 1)) - gammaln(0.5 * self.df)\ - - np.log(np.sqrt(np.pi * self.df) * self.scale) - (self.df + 1) / 2 *\ - np.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) + - xp.log((np.pi * self.df)**0.5 * self.scale) - (self.df + 1) / 2 *\ + xp.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) def cdf(self, val): return stdtr(self.df, (val - self.mu) / self.scale) @@ -988,7 +1023,8 @@ def rescale(self, val): """ return betaincinv(self.alpha, self.beta, val) * (self.maximum - self.minimum) + self.minimum - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. Parameters @@ -999,9 +1035,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Returns the log prior probability of val. 
Parameters @@ -1012,21 +1049,26 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - _ln_prob = xlogy(self.alpha - 1, val - self.minimum) + xlogy(self.beta - 1, self.maximum - val)\ - - betaln(self.alpha, self.beta) - xlogy(self.alpha + self.beta - 1, self.maximum - self.minimum) + _ln_prob = ( + xlogy(self.alpha - 1, val - self.minimum) + + xlogy(self.beta - 1, self.maximum - val) + - betaln(self.alpha, self.beta) + - xlogy(self.alpha + self.beta - 1, self.maximum - self.minimum) + ) # deal with the fact that if alpha or beta are < 1 you get infinities at 0 and 1 if isinstance(val, (float, int)): - if np.isfinite(_ln_prob) and self.minimum <= val <= self.maximum: + if xp.isfinite(_ln_prob) and self.minimum <= val <= self.maximum: return _ln_prob - return -np.inf + return -xp.inf else: - _ln_prob_sub = np.full_like(val, -np.inf) - idx = np.isfinite(_ln_prob) & (val >= self.minimum) & (val <= self.maximum) + _ln_prob_sub = xp.full_like(val, -xp.inf) + idx = xp.isfinite(_ln_prob) & (val >= self.minimum) & (val <= self.maximum) _ln_prob_sub[idx] = _ln_prob[idx] return _ln_prob_sub - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=np): if isinstance(val, (float, int)): if val > self.maximum: return 1. @@ -1040,8 +1082,9 @@ def cdf(self, val): else: _cdf = np.nan_to_num(betainc(self.alpha, self.beta, (val - self.minimum) / (self.maximum - self.minimum))) - _cdf[val < self.minimum] = 0. - _cdf[val > self.maximum] = 1. + _cdf *= val >= self.minimum + _cdf *= val <= self.maximum + _cdf += val > self.maximum return _cdf @@ -1397,7 +1440,8 @@ def rescale(self, val): inv = -1 / self.expr + (1 + self.expr)**-val + (1 + self.expr)**-val / self.expr return -self.sigma * np.log(np.maximum(inv, 0)) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. 
Parameters @@ -1409,8 +1453,8 @@ def prob(self, val): float: Prior probability of val """ return ( - (np.exp((val - self.mu) / self.sigma) + 1)**-1 - / (self.sigma * np.log1p(self.expr)) + (xp.exp((val - self.mu) / self.sigma) + 1)**-1 + / (self.sigma * xp.log1p(self.expr)) * (val >= self.minimum) ) @@ -1789,3 +1833,5 @@ def cdf(self, val): / (self.mode - self.rescaled_minimum) ) ) + + betaln, diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 3ead934f2..2e3f40df1 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -6,6 +6,7 @@ from warnings import warn import numpy as np +from scipy._lib._array_api import array_namespace from .analytical import DeltaFunction from .base import Prior, Constraint @@ -635,9 +636,9 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ - return list( - [self[key].rescale(sample) for key, sample in zip(keys, theta)] - ) + xp = array_namespace(theta) + + return xp.asarray([self[key].rescale(sample) for key, sample in zip(keys, theta)]) def test_redundancy(self, key, disable_logging=False): """Empty redundancy test, should be overwritten in subclasses""" @@ -862,8 +863,9 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ + xp = array_namespace(theta) + keys = list(keys) - theta = list(theta) self._check_resolved() self._update_rescale_keys(keys) result = dict() @@ -886,9 +888,9 @@ def rescale(self, keys, theta): # {a: [], b: [], c: [1, 2, 3, 4], d: []} # -> [1, 2, 3, 4] # -> {a: 1, b: 2, c: 3, d: 4} - values = list() + values = xp.array([]) for key in names: - values = np.concatenate([values, result[key]]) + values = xp.concatenate([values, result[key]]) for key, value in zip(names, values): result[key] = value @@ -903,7 +905,7 @@ def safe_flatten(value): else: return result[key].flatten() - return [safe_flatten(result[key]) for key in keys] + return xp.array([safe_flatten(result[key]) for key in keys]) def 
_update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From e816198001744994645ad08fb873ec6dbdd7a36e Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 14 Nov 2024 08:29:53 -0800 Subject: [PATCH 007/110] FEAT: make non-numpy arrays serializable --- bilby/core/utils/io.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index 8299d6816..a5502a1a6 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -59,8 +59,12 @@ def default(self, obj): return encode_astropy_unit(obj) except ImportError: logger.debug("Cannot import astropy, cannot write cosmological priors") - if isinstance(obj, np.ndarray): - return {"__array__": True, "content": obj.tolist()} + if hasattr(obj, "__array_namespace__"): + return { + "__array__": True, + "__array_namespace__": obj.__array_namespace__().__name__, + "content": obj.tolist(), + } if isinstance(obj, complex): return {"__complex__": True, "real": obj.real, "imag": obj.imag} if isinstance(obj, pd.DataFrame): @@ -320,7 +324,9 @@ def decode_bilby_json(dct): if dct.get("__astropy_unit__", False): return decode_astropy_unit(dct) if dct.get("__array__", False): - return np.asarray(dct["content"]) + namespace = dct.get("__array_namespace__", "numpy") + xp = import_module(namespace) + return xp.asarray(dct["content"]) if dct.get("__complex__", False): return complex(dct["real"], dct["imag"]) if dct.get("__dataframe__", False): @@ -438,6 +444,10 @@ def encode_for_hdf5(key, item): if item.dtype.kind == 'U': logger.debug(f'converting dtype {item.dtype} for hdf5') item = np.array(item, dtype='S') + elif hasattr(item, "__array_namespace__"): + # temporarily dump all arrays as numpy arrays, we should figure ou + # how to properly deserialize them + item = np.asarray(item) if isinstance(item, (np.ndarray, int, float, complex, str, bytes)): output = item elif isinstance(item, np.random.Generator): From 
7a785c2ff6c2790e27652dbdbf01a0ba79c7dd4a Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 14 Nov 2024 08:30:18 -0800 Subject: [PATCH 008/110] BUG: fix some array conversion methods --- bilby/gw/conversion.py | 4 ++-- bilby/gw/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 093d29ab6..ca47dec57 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -9,7 +9,6 @@ import pickle import numpy as np -from bilback.utils import array_module from pandas import DataFrame, Series from scipy.stats import norm @@ -27,6 +26,7 @@ lalsim_SimNeutronStarRadius, lalsim_SimNeutronStarLoveNumberK2) +from ..compat.utils import array_module from ..core.likelihood import MarginalizedLikelihoodReconstructionError from ..core.utils import logger, solar_mass, gravitational_constant, speed_of_light, command_line_args, safe_file_dump from ..core.prior import DeltaFunction @@ -241,7 +241,7 @@ def convert_to_lal_binary_black_hole_parameters(parameters): """ converted_parameters = parameters.copy() original_keys = list(converted_parameters.keys()) - xp = array_module(parameters[original_keys[0]]) + xp = array_module(parameters[original_keys[5]]) if 'luminosity_distance' not in original_keys: if 'redshift' in converted_parameters.keys(): converted_parameters['luminosity_distance'] = \ diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index 719664635..fcbaf2a86 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -9,7 +9,7 @@ from .geometry import zenith_azimuth_to_theta_phi from .time import greenwich_mean_sidereal_time -from ...compat.utils import array_module +from ..compat.utils import array_module from ..core.utils import (logger, run_commandline, check_directory_exists_and_if_not_mkdir, SamplesSummary, theta_phi_to_ra_dec) From 7ebf34044bfbf1cbfadac749da788415482a9446 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 11 Dec 2024 23:13:20 +0000 Subject: [PATCH 009/110] DEV: some more 
prior agnosticism --- bilby/core/prior/analytical.py | 151 ++++++------------------------- bilby/core/prior/interpolated.py | 4 +- bilby/core/prior/joint.py | 20 ++-- test/core/prior/prior_test.py | 23 +++++ 4 files changed, 67 insertions(+), 131 deletions(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 974d5cbe4..713c27fb3 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -7,7 +7,6 @@ betaincinv, betaln, erf, - # erfinv, # erfinv is not currently backend agnostic gammaincinv, gammainc, gammaln, @@ -176,7 +175,7 @@ def cdf(self, val, *, xp=np): _cdf = xp.log(val / self.minimum) / xp.log(self.maximum / self.minimum) else: _cdf = ( - val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1) + (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) / (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) ) _cdf = xp.clip(_cdf, 0, 1) @@ -332,19 +331,7 @@ def rescale(self, val, *, xp=np): ======= Union[float, array_like]: Rescaled probability """ - if isinstance(val, (float, int)): - if val < 0.5: - return -self.maximum * xp.exp(-2 * val * xp.log(self.maximum / self.minimum)) - else: - return self.minimum * xp.exp(xp.log(self.maximum / self.minimum) * (2 * val - 1)) - else: - vals_less_than_5 = val < 0.5 - rescaled = xp.empty_like(val) - rescaled[vals_less_than_5] = -self.maximum * xp.exp(-2 * val[vals_less_than_5] * - xp.log(self.maximum / self.minimum)) - rescaled[~vals_less_than_5] = self.minimum * xp.exp(xp.log(self.maximum / self.minimum) * - (2 * val[~vals_less_than_5] - 1)) - return rescaled + return xp.sign(2 * val - 1) * self.minimum * xp.exp(abs(2 * val - 1) * xp.log(self.maximum / self.minimum)) @xp_wrap def prob(self, val, *, xp=np): @@ -379,7 +366,7 @@ def ln_prob(self, val, *, xp=np): @xp_wrap def cdf(self, val, *, xp=np): - asymmetric = LogUniform.cdf(self, abs(val), xp) + asymmetric = xp.log(abs(val) / self.minimum) / xp.log(self.maximum / self.minimum) return 
0.5 * (1 + xp.sign(val) * asymmetric) @@ -746,18 +733,10 @@ def prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val <= self.minimum: - _prob = 0. - else: - _prob = xp.exp(-(xp.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2)\ - / xp.sqrt(2 * np.pi) / val / self.sigma - else: - _prob = xp.zeros(val.size) - idx = (val > self.minimum) - _prob[idx] = xp.exp(-(xp.log(val[idx]) - self.mu) ** 2 / self.sigma ** 2 / 2)\ - / xp.sqrt(2 * np.pi) / val[idx] / self.sigma - return _prob + return ( + xp.exp(-(xp.log(xp.maximum(val, self.minimum)) - self.mu) ** 2 / self.sigma ** 2 / 2) + / xp.sqrt(2 * np.pi) / val / self.sigma + ) * (val > self.minimum) @xp_wrap def ln_prob(self, val, *, xp=np): @@ -771,31 +750,16 @@ def ln_prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val <= self.minimum: - _ln_prob = -xp.inf - else: - _ln_prob = -(xp.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2\ - - xp.log(xp.sqrt(2 * np.pi) * val * self.sigma) - else: - _ln_prob = -xp.inf * xp.ones(val.size) - idx = (val > self.minimum) - _ln_prob[idx] = -(xp.log(val[idx]) - self.mu) ** 2\ - / self.sigma ** 2 / 2 - xp.log(xp.sqrt(2 * np.pi) * val[idx] * self.sigma) - return _ln_prob + return ( + -(xp.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2 + - xp.log(xp.sqrt(2 * np.pi) * val * self.sigma) + ) + xp.log(val > self.minimum) @xp_wrap def cdf(self, val, *, xp=np): - if isinstance(val, (float, int)): - if val <= self.minimum: - _cdf = 0. 
- else: - _cdf = 0.5 + erf((np.log(val) - self.mu) / self.sigma / np.sqrt(2)) / 2 - else: - _cdf = xp.zeros(val.size) - _cdf[val > self.minimum] = 0.5 + erf(( - xp.log(val[val > self.minimum]) - self.mu) / self.sigma / np.sqrt(2)) / 2 - return _cdf + return 0.5 + erf( + (xp.log(xp.maximum(val, self.minimum)) - self.mu) / self.sigma / np.sqrt(2) + ) / 2 class LogGaussian(LogNormal): @@ -844,15 +808,7 @@ def prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val < self.minimum: - _prob = 0. - else: - _prob = xp.exp(-val / self.mu) / self.mu - else: - _prob = xp.zeros(val.size) - _prob[val >= self.minimum] = xp.exp(-val[val >= self.minimum] / self.mu) / self.mu - return _prob + return xp.exp(-val / self.mu) / self.mu * (val >= self.minimum) @xp_wrap def ln_prob(self, val, *, xp=np): @@ -866,27 +822,11 @@ def ln_prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val < self.minimum: - _ln_prob = -xp.inf - else: - _ln_prob = -val / self.mu - xp.log(self.mu) - else: - _ln_prob = -xp.inf * xp.ones(val.size) - _ln_prob[val >= self.minimum] = -val[val >= self.minimum] / self.mu - xp.log(self.mu) - return _ln_prob + return -val / self.mu - xp.log(self.mu) + xp.log(val > self.minimum) @xp_wrap def cdf(self, val, *, xp=np): - if isinstance(val, (float, int)): - if val < self.minimum: - _cdf = 0. - else: - _cdf = 1. - xp.exp(-val / self.mu) - else: - _cdf = xp.zeros(val.size) - _cdf[val >= self.minimum] = 1. - xp.exp(-val[val >= self.minimum] / self.mu) - return _cdf + return xp.maximum(1. - xp.exp(-val / self.mu), 0) class StudentT(Prior): @@ -930,18 +870,11 @@ def rescale(self, val, *, xp=np): This maps to the inverse CDF. This has been analytically solved for this case. 
""" - if isinstance(val, (float, int)): - if val == 0: - rescaled = -np.inf - elif val == 1: - rescaled = np.inf - else: - rescaled = stdtrit(self.df, val) * self.scale + self.mu - else: - rescaled = stdtrit(self.df, val) * self.scale + self.mu - rescaled[val == 0] = -xp.inf - rescaled[val == 1] = xp.inf - return rescaled + return ( + xp.nan_to_num(stdtrit(self.df, val) * self.scale + self.mu) + + xp.log(val > 0) + - xp.log(val < 1) + ) @xp_wrap def prob(self, val, *, xp=np): @@ -1117,25 +1050,14 @@ def __init__(self, mu, scale, name=None, latex_label=None, unit=None, boundary=N self.mu = mu self.scale = scale - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the appropriate Logistic prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - if isinstance(val, (float, int)): - if val == 0: - rescaled = -np.inf - elif val == 1: - rescaled = np.inf - else: - rescaled = self.mu + self.scale * np.log(val / (1. - val)) - else: - rescaled = np.inf * np.ones(val.size) - rescaled[val == 0] = -np.inf - rescaled[(val > 0) & (val < 1)] = self.mu + self.scale\ - * np.log(val[(val > 0) & (val < 1)] / (1. - val[(val > 0) & (val < 1)])) - return rescaled + return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), 0)) def prob(self, val): """Return the prior probability of val. @@ -1197,21 +1119,14 @@ def __init__(self, alpha, beta, name=None, latex_label=None, unit=None, boundary self.alpha = alpha self.beta = beta - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the appropriate Cauchy prior. This maps to the inverse CDF. This has been analytically solved for this case. 
""" - rescaled = self.alpha + self.beta * np.tan(np.pi * (val - 0.5)) - if isinstance(val, (float, int)): - if val == 1: - rescaled = np.inf - elif val == 0: - rescaled = -np.inf - else: - rescaled[val == 1] = np.inf - rescaled[val == 0] = -np.inf + rescaled = self.alpha + self.beta * xp.tan(np.pi * (val - 0.5)) return rescaled def prob(self, val): @@ -1323,15 +1238,7 @@ def ln_prob(self, val): return _ln_prob def cdf(self, val): - if isinstance(val, (float, int)): - if val < self.minimum: - _cdf = 0. - else: - _cdf = gammainc(self.k, val / self.theta) - else: - _cdf = np.zeros(val.size) - _cdf[val >= self.minimum] = gammainc(self.k, val[val >= self.minimum] / self.theta) - return _cdf + return gammainc(self.k, xp.maximum(val, self.minimum) / self.theta) class ChiSquared(Gamma): diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index 5fbf8f9c1..d47f14209 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -3,6 +3,7 @@ from .base import Prior from ..utils import logger, WrappedInterp1d as interp1d +from ...compat.utils import xp_wrap class Interped(Prior): @@ -80,7 +81,8 @@ def prob(self, val): def cdf(self, val): return self.cumulative_distribution(val) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the prior. diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 43c8913e3..06a740497 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -7,6 +7,7 @@ from .base import Prior, PriorException from ..utils import logger, infer_args_from_method, get_dict_with_properties from ..utils import random +from ...compat.utils import xp_wrap class BaseJointPriorDist(object): @@ -295,7 +296,8 @@ def _sample(self, size, **kwargs): """ return samps - def rescale(self, value, **kwargs): + @xp_wrap + def rescale(self, value, *, xp=np, **kwargs): """ Rescale from a unit hypercube to JointPriorDist. 
Note that no bounds are applied in the rescale function. (child classes need to @@ -317,7 +319,7 @@ def rescale(self, value, **kwargs): An vector sample drawn from the multivariate Gaussian distribution. """ - samp = np.array(value) + samp = xp.array(value) if len(samp.shape) == 1: samp = samp.reshape(1, self.num_vars) @@ -327,7 +329,7 @@ def rescale(self, value, **kwargs): raise ValueError("Array is the wrong shape") samp = self._rescale(samp, **kwargs) - return np.squeeze(samp) + return xp.squeeze(samp) def _rescale(self, samp, **kwargs): """ @@ -611,7 +613,9 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): scipy.stats.multivariate_normal(mean=np.zeros(self.num_vars), cov=self.corrcoefs[-1]) ) - def _rescale(self, samp, **kwargs): + @xp_wrap + def _rescale(self, samp, *, xp=np, **kwargs): + print(samp, xp) try: mode = kwargs["mode"] except KeyError: @@ -626,7 +630,7 @@ def _rescale(self, samp, **kwargs): samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 # rotate and scale to the multivariate normal shape - samp = self.mus[mode] + self.sigmas[mode] * np.einsum( + samp = self.mus[mode] + self.sigmas[mode] * xp.einsum( "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode] ) return samp @@ -778,7 +782,8 @@ def maximum(self, maximum): self._maximum = maximum self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum) - def rescale(self, val, **kwargs): + @xp_wrap + def rescale(self, val, *, xp=np, **kwargs): """ Scale a unit hypercube sample to the prior. @@ -793,11 +798,10 @@ def rescale(self, val, **kwargs): float: A sample from the prior parameter. 
""" - self.dist.rescale_parameters[self.name] = val if self.dist.filled_rescale(): - values = np.array(list(self.dist.rescale_parameters.values())).T + values = xp.array(list(self.dist.rescale_parameters.values())).T samples = self.dist.rescale(values, **kwargs) self.dist.reset_rescale() return samples diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 17d360d0c..0643dc4df 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -848,6 +848,29 @@ def test_set_minimum_setting(self): continue prior.minimum = (prior.maximum + prior.minimum) / 2 self.assertTrue(min(prior.sample(10000)) > prior.minimum) + + def test_jax_rescale(self): + import jax + + points = jax.numpy.linspace(1e-3, 1 - 1e-3, 10) + for prior in self.priors: + if isinstance( + prior, ( + bilby.core.prior.StudentT, + bilby.core.prior.Beta, + bilby.core.prior.Gamma, + ), + ) or bilby.core.prior.JointPrior in prior.__class__.__mro__: + continue + print(prior) + scaled = prior.rescale(points) + assert isinstance(scaled, jax.Array) + if isinstance(prior, bilby.core.prior.DeltaFunction): + continue + assert max(abs(prior.cdf(scaled) - points)) < 1e-6 + probs = prior.prob(scaled) + assert min(probs) > 0 + assert max(abs(jax.numpy.log(probs) - prior.ln_prob(scaled))) < 1e-6 if __name__ == "__main__": From b558ea672124397295cb20e51cbba73deee0af30 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 12 Dec 2024 22:45:57 +0000 Subject: [PATCH 010/110] TEST: make all prior tests run This required making some changes to the tests for conditional dicts as I've changed the output types and the backend introspection doesn't work on dict_items for some reason --- bilby/core/prior/analytical.py | 154 ++++++++++++++-------------- bilby/core/prior/conditional.py | 1 + bilby/core/prior/dict.py | 34 +----- bilby/core/prior/slabspike.py | 65 +++++------- test/core/prior/conditional_test.py | 9 +- test/core/prior/prior_test.py | 35 +++---- 
test/core/prior/slabspike_test.py | 13 +++ 7 files changed, 140 insertions(+), 171 deletions(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 713c27fb3..876196ec6 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -172,7 +172,8 @@ def ln_prob(self, val, *, xp=np): @xp_wrap def cdf(self, val, *, xp=np): if self.alpha == -1: - _cdf = xp.log(val / self.minimum) / xp.log(self.maximum / self.minimum) + with np.errstate(invalid="ignore"): + _cdf = xp.log(val / self.minimum) / xp.log(self.maximum / self.minimum) else: _cdf = ( (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) @@ -733,10 +734,7 @@ def prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - return ( - xp.exp(-(xp.log(xp.maximum(val, self.minimum)) - self.mu) ** 2 / self.sigma ** 2 / 2) - / xp.sqrt(2 * np.pi) / val / self.sigma - ) * (val > self.minimum) + return xp.exp(self.ln_prob(val, xp=xp)) @xp_wrap def ln_prob(self, val, *, xp=np): @@ -750,16 +748,18 @@ def ln_prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - return ( - -(xp.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2 - - xp.log(xp.sqrt(2 * np.pi) * val * self.sigma) - ) + xp.log(val > self.minimum) + with np.errstate(divide="ignore", invalid="ignore"): + return xp.nan_to_num(( + -(xp.log(xp.maximum(val, self.minimum)) - self.mu) ** 2 / self.sigma ** 2 / 2 + - xp.log(xp.sqrt(2 * np.pi) * val * self.sigma) + ) + xp.log(val > self.minimum), nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) @xp_wrap def cdf(self, val, *, xp=np): - return 0.5 + erf( - (xp.log(xp.maximum(val, self.minimum)) - self.mu) / self.sigma / np.sqrt(2) - ) / 2 + with np.errstate(divide="ignore"): + return 0.5 + erf( + (xp.log(xp.maximum(val, self.minimum)) - self.mu) / self.sigma / np.sqrt(2) + ) / 2 class LogGaussian(LogNormal): @@ -794,7 +794,8 @@ def rescale(self, val, *, xp=np): This maps to the inverse CDF. 
This has been analytically solved for this case. """ - return -self.mu * xp.log1p(-val) + with np.errstate(divide="ignore", over="ignore"): + return -self.mu * xp.log1p(-val) @xp_wrap def prob(self, val, *, xp=np): @@ -808,7 +809,7 @@ def prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - return xp.exp(-val / self.mu) / self.mu * (val >= self.minimum) + return xp.exp(self.ln_prob(val, xp=xp)) @xp_wrap def ln_prob(self, val, *, xp=np): @@ -822,11 +823,13 @@ def ln_prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - return -val / self.mu - xp.log(self.mu) + xp.log(val > self.minimum) + with np.errstate(divide="ignore"): + return -val / self.mu - xp.log(self.mu) + xp.log(val >= self.minimum) @xp_wrap def cdf(self, val, *, xp=np): - return xp.maximum(1. - xp.exp(-val / self.mu), 0) + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + return xp.maximum(1. - xp.exp(-val / self.mu), 0) class StudentT(Prior): @@ -869,12 +872,17 @@ def rescale(self, val, *, xp=np): 'Rescale' a sample from the unit line element to the appropriate Student's t-prior. This maps to the inverse CDF. This has been analytically solved for this case. + + Notes + ===== + This explicitly casts to the requested backend, but the computation will be done by scipy. 
""" - return ( - xp.nan_to_num(stdtrit(self.df, val) * self.scale + self.mu) - + xp.log(val > 0) - - xp.log(val < 1) - ) + with np.errstate(divide="ignore", invalid="ignore"): + return ( + xp.nan_to_num(stdtrit(self.df, val) * self.scale + self.mu) + + xp.log(val > 0) + - xp.log(val < 1) + ) @xp_wrap def prob(self, val, *, xp=np): @@ -902,9 +910,11 @@ def ln_prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - return gammaln(0.5 * (self.df + 1)) - gammaln(0.5 * self.df)\ - - xp.log((np.pi * self.df)**0.5 * self.scale) - (self.df + 1) / 2 *\ - xp.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) + return ( + gammaln(0.5 * (self.df + 1)) - gammaln(0.5 * self.df) + - xp.log((np.pi * self.df)**0.5 * self.scale) - (self.df + 1) / 2 + * xp.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) + ) def cdf(self, val): return stdtr(self.df, (val - self.mu) / self.scale) @@ -948,13 +958,21 @@ def __init__(self, alpha, beta, minimum=0, maximum=1, name=None, self.alpha = alpha self.beta = beta - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the appropriate Beta prior. This maps to the inverse CDF. This has been analytically solved for this case. + + Notes + ===== + This explicitly casts to the requested backend, but the computation will be done by scipy. 
""" - return betaincinv(self.alpha, self.beta, val) * (self.maximum - self.minimum) + self.minimum + return ( + xp.asarray(betaincinv(self.alpha, self.beta, val)) * (self.maximum - self.minimum) + + self.minimum + ) @xp_wrap def prob(self, val, *, xp=np): @@ -983,42 +1001,18 @@ def ln_prob(self, val, *, xp=np): Union[float, array_like]: Prior probability of val """ _ln_prob = ( - xlogy(self.alpha - 1, val - self.minimum) - + xlogy(self.beta - 1, self.maximum - val) + xlogy(xp.asarray(self.alpha - 1), val - self.minimum) + + xlogy(xp.asarray(self.beta - 1), self.maximum - val) - betaln(self.alpha, self.beta) - xlogy(self.alpha + self.beta - 1, self.maximum - self.minimum) ) - - # deal with the fact that if alpha or beta are < 1 you get infinities at 0 and 1 - if isinstance(val, (float, int)): - if xp.isfinite(_ln_prob) and self.minimum <= val <= self.maximum: - return _ln_prob - return -xp.inf - else: - _ln_prob_sub = xp.full_like(val, -xp.inf) - idx = xp.isfinite(_ln_prob) & (val >= self.minimum) & (val <= self.maximum) - _ln_prob_sub[idx] = _ln_prob[idx] - return _ln_prob_sub + return xp.nan_to_num(_ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) @xp_wrap def cdf(self, val, *, xp=np): - if isinstance(val, (float, int)): - if val > self.maximum: - return 1. - elif val < self.minimum: - return 0. - else: - return betainc( - self.alpha, self.beta, - (val - self.minimum) / (self.maximum - self.minimum) - ) - else: - _cdf = np.nan_to_num(betainc(self.alpha, self.beta, - (val - self.minimum) / (self.maximum - self.minimum))) - _cdf *= val >= self.minimum - _cdf *= val <= self.maximum - _cdf += val > self.maximum - return _cdf + return xp.nan_to_num( + betainc(self.alpha, self.beta, (val - self.minimum) / (self.maximum - self.minimum)) + ) + (val > self.maximum) class Logistic(Prior): @@ -1057,7 +1051,9 @@ def rescale(self, val, *, xp=np): This maps to the inverse CDF. This has been analytically solved for this case. 
""" - return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), 0)) + with np.errstate(divide="ignore"): + val = xp.asarray(val) + return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), 0)) def prob(self, val): """Return the prior probability of val. @@ -1072,7 +1068,8 @@ def prob(self, val): """ return np.exp(self.ln_prob(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Returns the log prior probability of val. Parameters @@ -1083,8 +1080,9 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return -(val - self.mu) / self.scale -\ - 2. * np.log(1. + np.exp(-(val - self.mu) / self.scale)) - np.log(self.scale) + with np.errstate(over="ignore"): + return -(val - self.mu) / self.scale -\ + 2. * np.log1p(xp.exp(-(val - self.mu) / self.scale)) - np.log(self.scale) def cdf(self, val): return 1. / (1. + np.exp(-(val - self.mu) / self.scale)) @@ -1127,7 +1125,8 @@ def rescale(self, val, *, xp=np): This maps to the inverse CDF. This has been analytically solved for this case. """ rescaled = self.alpha + self.beta * xp.tan(np.pi * (val - 0.5)) - return rescaled + with np.errstate(divide="ignore", invalid="ignore"): + return rescaled - xp.log(val < 1) + xp.log(val > 0) def prob(self, val): """Return the prior probability of val. @@ -1193,15 +1192,17 @@ def __init__(self, k, theta=1., name=None, latex_label=None, unit=None, boundary self.k = k self.theta = theta - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the appropriate Gamma prior. This maps to the inverse CDF. This has been analytically solved for this case. """ - return gammaincinv(self.k, val) * self.theta + return xp.asarray(gammaincinv(self.k, val)) * self.theta - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. 
Parameters @@ -1212,9 +1213,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val)) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Returns the log prior probability of val. Parameters @@ -1225,20 +1227,16 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - if isinstance(val, (float, int)): - if val < self.minimum: - _ln_prob = -np.inf - else: - _ln_prob = xlogy(self.k - 1, val) - val / self.theta - xlogy(self.k, self.theta) - gammaln(self.k) - else: - _ln_prob = -np.inf * np.ones(val.size) - idx = (val >= self.minimum) - _ln_prob[idx] = xlogy(self.k - 1, val[idx]) - val[idx] / self.theta\ + with np.errstate(divide="ignore"): + ln_prob = ( + xlogy(xp.asarray(self.k - 1), val) - val / self.theta - xlogy(self.k, self.theta) - gammaln(self.k) - return _ln_prob + ) + xp.log(val >= self.minimum) + return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=xp.inf) - def cdf(self, val): - return gammainc(self.k, xp.maximum(val, self.minimum) / self.theta) + @xp_wrap + def cdf(self, val, *, xp=np): + return gammainc(xp.asarray(self.k), xp.maximum(val, self.minimum) / self.theta) class ChiSquared(Gamma): diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 7c2a739e2..5e34e70f9 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -164,6 +164,7 @@ class depending on the required variables it depends on. self.reference_params will be used. 
""" + required_variables.pop("xp", None) if sorted(list(required_variables)) == sorted(self.required_variables): parameters = self.condition_func(self.reference_params.copy(), **required_variables) for key, value in parameters.items(): diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 2e3f40df1..09480a099 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -636,6 +636,8 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ + if isinstance(theta, {}.values().__class__): + theta = list(theta) xp = array_namespace(theta) return xp.asarray([self[key].rescale(sample) for key, sample in zip(keys, theta)]) @@ -863,6 +865,8 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ + if isinstance(theta, {}.values().__class__): + theta = list(theta) xp = array_namespace(theta) keys = list(keys) @@ -877,35 +881,7 @@ def rescale(self, keys, theta): theta[index], **self.get_required_variables(key) ) self[key].least_recently_sampled = result[key] - if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint: - joint[self[key].dist.distname] = [key] - elif isinstance(self[key], JointPrior): - joint[self[key].dist.distname].append(key) - for names in joint.values(): - # this is needed to unpack how joint prior rescaling works - # as an example of a joint prior over {a, b, c, d} we might - # get the following based on the order within the joint prior - # {a: [], b: [], c: [1, 2, 3, 4], d: []} - # -> [1, 2, 3, 4] - # -> {a: 1, b: 2, c: 3, d: 4} - values = xp.array([]) - for key in names: - values = xp.concatenate([values, result[key]]) - for key, value in zip(names, values): - result[key] = value - - def safe_flatten(value): - """ - this is gross but can be removed whenever we switch to returning - arrays, flatten converts 0-d arrays to 1-d so has to be special - cased - """ - if isinstance(value, (float, int)): - return value - else: - 
return result[key].flatten() - - return xp.array([safe_flatten(result[key]) for key in keys]) + return xp.concatenate([result[key] for key in keys], axis=None) def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index 6910be608..ff823a369 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -1,8 +1,8 @@ -from numbers import Number import numpy as np from .base import Prior from ..utils import logger +from ...compat.utils import xp_wrap class SlabSpikePrior(Prior): @@ -72,7 +72,8 @@ def slab_fraction(self): def _find_inverse_cdf_fraction_before_spike(self): return float(self.slab.cdf(self.spike_location)) * self.slab_fraction - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the prior. @@ -85,25 +86,19 @@ def rescale(self, val): ======= array_like: Associated prior value with input value. 
""" - original_is_number = isinstance(val, Number) - val = np.atleast_1d(val) - lower_indices = val < self.inverse_cdf_below_spike - intermediate_indices = np.logical_and( - self.inverse_cdf_below_spike <= val, - val <= (self.inverse_cdf_below_spike + self.spike_height)) - higher_indices = val > (self.inverse_cdf_below_spike + self.spike_height) - - res = np.zeros(len(val)) - res[lower_indices] = self._contracted_rescale(val[lower_indices]) - res[intermediate_indices] = self.spike_location - res[higher_indices] = self._contracted_rescale(val[higher_indices] - self.spike_height) - if original_is_number: - try: - res = res[0] - except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + intermediate_indices = ( + (self.inverse_cdf_below_spike <= val) + * (val < (self.inverse_cdf_below_spike + self.spike_height)) + ) + higher_indices = val >= (self.inverse_cdf_below_spike + self.spike_height) + + slab_scaled = self._contracted_rescale(val - self.spike_height * higher_indices) + + res = xp.select( + [lower_indices | higher_indices, intermediate_indices], + [slab_scaled, self.spike_location], + ) return res def _contracted_rescale(self, val): @@ -122,7 +117,8 @@ def _contracted_rescale(self, val): """ return self.slab.rescale(val / self.slab_fraction) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. 
Returns np.inf for the spike location @@ -134,19 +130,13 @@ def prob(self, val): ======= array_like: Prior probability of val """ - original_is_number = isinstance(val, Number) res = self.slab.prob(val) * self.slab_fraction - res = np.atleast_1d(res) - res[val == self.spike_location] = np.inf - if original_is_number: - try: - res = res[0] - except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + with np.errstate(invalid="ignore"): + res += xp.nan_to_num(xp.inf * (val == self.spike_location), posinf=xp.inf) return res - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Return the Log prior probability of val. Returns np.inf for the spike location @@ -158,16 +148,9 @@ def ln_prob(self, val): ======= array_like: Prior probability of val """ - original_is_number = isinstance(val, Number) res = self.slab.ln_prob(val) + np.log(self.slab_fraction) - res = np.atleast_1d(res) - res[val == self.spike_location] = np.inf - if original_is_number: - try: - res = res[0] - except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + with np.errstate(divide="ignore"): + res += xp.nan_to_num(xp.inf * (val == self.spike_location), posinf=xp.inf) return res def cdf(self, val): @@ -185,5 +168,5 @@ def cdf(self, val): """ res = self.slab.cdf(val) * self.slab_fraction - res += self.spike_height * (val > self.spike_location) + res += (val > self.spike_location) * self.spike_height return res diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 20c0cda93..68db12ed7 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -324,7 +324,7 @@ def test_rescale(self): expected = [self.test_sample["var_0"]] for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) - 
self.assertListEqual(expected, res) + np.testing.assert_array_equal(expected, res) def test_rescale_with_joint_prior(self): """ @@ -353,7 +353,6 @@ def test_rescale_with_joint_prior(self): keys = list(self.test_sample.keys()) + names res = priordict.rescale(keys=keys, theta=ref_variables) - self.assertIsInstance(res, list) self.assertEqual(np.shape(res), (6,)) self.assertListEqual([isinstance(r, float) for r in res], 6 * [True]) @@ -361,7 +360,7 @@ def test_rescale_with_joint_prior(self): expected = [self.test_sample["var_0"]] for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) - self.assertListEqual(expected, res[0:4]) + np.testing.assert_array_equal(expected, res[:4]) def test_cdf(self): """ @@ -370,11 +369,11 @@ def test_cdf(self): Note that the format of inputs/outputs is different between the two methods. """ sample = self.conditional_priors.sample() - self.assertEqual( + np.testing.assert_array_equal( self.conditional_priors.rescale( sample.keys(), self.conditional_priors.cdf(sample=sample).values() - ), list(sample.values()) + ), np.array(list(sample.values())) ) def test_rescale_illegal_conditions(self): diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 0643dc4df..14f864e90 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -564,10 +564,14 @@ def test_probability_in_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in self.priors: if prior.minimum == -np.inf: - prior.minimum = -1e5 + minimum = -1e5 + else: + minimum = prior.minimum if prior.maximum == np.inf: - prior.maximum = 1e5 - domain = np.linspace(prior.minimum, prior.maximum, 1000) + maximum = 1e5 + else: + maximum = prior.maximum + domain = np.linspace(minimum, maximum, 1000) self.assertTrue(all(prior.prob(domain) >= 0)) def test_probability_surrounding_domain(self): @@ -579,13 +583,14 @@ def test_probability_surrounding_domain(self): if 
isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) - indomain = (surround_domain >= prior.minimum) | ( - surround_domain <= prior.maximum - ) - outdomain = (surround_domain < prior.minimum) | ( - surround_domain > prior.maximum - ) + with np.errstate(invalid="ignore"): + surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) + indomain = (surround_domain >= prior.minimum) | ( + surround_domain <= prior.maximum + ) + outdomain = (surround_domain < prior.minimum) | ( + surround_domain > prior.maximum + ) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_request(): continue @@ -849,18 +854,12 @@ def test_set_minimum_setting(self): prior.minimum = (prior.maximum + prior.minimum) / 2 self.assertTrue(min(prior.sample(10000)) > prior.minimum) - def test_jax_rescale(self): + def test_jax_methods(self): import jax points = jax.numpy.linspace(1e-3, 1 - 1e-3, 10) for prior in self.priors: - if isinstance( - prior, ( - bilby.core.prior.StudentT, - bilby.core.prior.Beta, - bilby.core.prior.Gamma, - ), - ) or bilby.core.prior.JointPrior in prior.__class__.__mro__: + if bilby.core.prior.JointPrior in prior.__class__.__mro__: continue print(prior) scaled = prior.rescale(points) diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index d2cdcc55a..8cb2fcf1d 100644 --- a/test/core/prior/slabspike_test.py +++ b/test/core/prior/slabspike_test.py @@ -102,6 +102,19 @@ def tearDown(self): del self.test_nodes_finite_support del self.test_nodes_infinite_support + def test_jax_methods(self): + import jax + + points = jax.numpy.linspace(1e-3, 1 - 1e-3, 10) + for prior in self.slab_spikes: + scaled = prior.rescale(points) + assert isinstance(scaled, jax.Array) + if isinstance(prior, bilby.core.prior.DeltaFunction): + continue + probs = prior.prob(scaled) + 
assert min(probs) > 0 + assert max(abs(jax.numpy.log(probs) - prior.ln_prob(scaled))) < 1e-6 + def test_prob_on_slab(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): expected = slab.prob(test_nodes) * slab_spike.slab_fraction From af8d604a02fa8aad6b22b022d19da881a73c0437 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 25 Jan 2025 03:24:23 -0800 Subject: [PATCH 011/110] DEV: move some jax functionality to compat --- bilby/compat/jax.py | 73 +++++++++++++++++++++++++++++++++++++++++ bilby/gw/jaxstuff.py | 78 -------------------------------------------- 2 files changed, 73 insertions(+), 78 deletions(-) create mode 100644 bilby/compat/jax.py diff --git a/bilby/compat/jax.py b/bilby/compat/jax.py new file mode 100644 index 000000000..8d297487b --- /dev/null +++ b/bilby/compat/jax.py @@ -0,0 +1,73 @@ +from functools import partial + +import jax +import jax.numpy as jnp +from ..core.likelihood import Likelihood + + +def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True): + """ + A wrapper to allow a :code:`Bilby` likelihood to be used with :code:`jax`. + + Parameters + ========== + likelihood: bilby.core.likelihood.Likelihood + The likelihood to evaluate. + parameters: dict + The parameters to evaluate the likelihood at. + use_ratio: bool, optional + Whether to evaluate the likelihood ratio or the full likelihood. + Default is :code:`True`. + """ + parameters = {k: jnp.array(v) for k, v in parameters.items()} + likelihood.parameters.update(parameters) + if use_ratio: + return likelihood.log_likelihood_ratio() + else: + return likelihood.log_likelihood() + + +class JittedLikelihood(Likelihood): + """ + A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`. + + .. note:: + + This is currently hardcoded to return the log likelihood ratio, regardless of + the input. + + Parameters + ========== + likelihood: bilby.core.likelihood.Likelihood + The likelihood to wrap. 
+ likelihood_func: callable, optional + The function to use to evaluate the likelihood. Default is + :code:`generic_bilby_likelihood_function`. This function should take the + likelihood and parameters as arguments along with additional keyword arguments. + kwargs: dict, optional + Additional keyword arguments to pass to the likelihood function. + """ + + def __init__( + self, + likelihood, + likelihood_func=generic_bilby_likelihood_function, + kwargs=None, + cast_to_float=True, + ): + if kwargs is None: + kwargs = dict() + self.kwargs = kwargs + self._likelihood = likelihood + self.likelihood_func = jax.jit(partial(likelihood_func, likelihood)) + self.cast_to_float = cast_to_float + super().__init__(dict()) + + def __getattr__(self, name): + return getattr(self._likelihood, name) + + def log_likelihood_ratio(self): + ln_l = jnp.nan_to_num(self.likelihood_func(self.parameters, **self.kwargs)) + if self.cast_to_float: + ln_l = float(ln_l) + return ln_l diff --git a/bilby/gw/jaxstuff.py b/bilby/gw/jaxstuff.py index f1e1c57b0..6046e51bd 100644 --- a/bilby/gw/jaxstuff.py +++ b/bilby/gw/jaxstuff.py @@ -4,14 +4,8 @@ idea of how much pain is being added. """ -from functools import partial - -from bilby.core.likelihood import Likelihood - import jax import jax.numpy as jnp -from plum import dispatch -from jax.scipy.special import i0e from ripple.waveforms import IMRPhenomPv2 @@ -71,75 +65,3 @@ def ripple_bbh( hp, hc = wf_func(frequencies, theta, jax.numpy.array(20.0)) return dict(plus=hp, cross=hc) - -def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True): - """ - A wrapper to allow a :code:`Bilby` likelihood to be used with :code:`jax`. - - Parameters - ========== - likelihood: bilby.core.likelihood.Likelihood - The likelihood to evaluate. - parameters: dict - The parameters to evaluate the likelihood at. - use_ratio: bool, optional - Whether to evaluate the likelihood ratio or the full likelihood. - Default is :code:`True`. 
- """ - parameters = {k: jnp.array(v) for k, v in parameters.items()} - likelihood.parameters.update(parameters) - if use_ratio: - return likelihood.log_likelihood_ratio() - else: - return likelihood.log_likelihood() - - -class JittedLikelihood(Likelihood): - """ - A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`. - - .. note:: - - This is currently hardcoded to return the log likelihood ratio, regardless of - the input. - - Parameters - ========== - likelihood: bilby.core.likelihood.Likelihood - The likelihood to wrap. - likelihood_func: callable, optional - The function to use to evaluate the likelihood. Default is - :code:`generic_bilby_likelihood_function`. This function should take the - likelihood and parameters as arguments along with additional keyword arguments. - kwargs: dict, optional - Additional keyword arguments to pass to the likelihood function. - """ - - def __init__( - self, - likelihood, - likelihood_func=generic_bilby_likelihood_function, - kwargs=None, - cast_to_float=True, - ): - if kwargs is None: - kwargs = dict() - self.kwargs = kwargs - self._likelihood = likelihood - self.likelihood_func = jax.jit(partial(likelihood_func, likelihood)) - self.cast_to_float = cast_to_float - super().__init__(dict()) - - def __getattr__(self, name): - return getattr(self._likelihood, name) - - def log_likelihood_ratio(self): - ln_l = jnp.nan_to_num(self.likelihood_func(self.parameters, **self.kwargs)) - if self.cast_to_float: - ln_l = float(ln_l) - return ln_l - - -@dispatch -def ln_i0(value: jax.Array): - return jnp.log(i0e(value)) + jnp.abs(value) From 5b5fa6b75779a66517f098807c93bd84f527f166 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 25 Jan 2025 03:25:55 -0800 Subject: [PATCH 012/110] REFACTOR: use array backend for ln_i0 --- bilby/gw/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index fcbaf2a86..f2a7499e4 100644 --- a/bilby/gw/utils.py 
+++ b/bilby/gw/utils.py @@ -1003,7 +1003,6 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label= plt.xlim(freq_points.min() - .5, freq_points.max() + 50) -@dispatch def ln_i0(value): """ A numerically stable method to evaluate ln(I_0) a modified Bessel function @@ -1019,7 +1018,8 @@ def ln_i0(value): array-like: The natural logarithm of the bessel function """ - return np.log(i0e(value)) + np.abs(value) + xp = array_module(value) + return xp.log(i0e(value)) + xp.abs(value) def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1): From c52be6995e032a23caabfc3fd7e3ad1dbdea9f25 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 25 Jan 2025 03:28:34 -0800 Subject: [PATCH 013/110] make distance marginalizatio backend transparent --- bilby/gw/likelihood/base.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index d98d4f481..a63e5d1bb 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -780,12 +780,12 @@ def distance_marginalized_likelihood(self, d_inner_h, h_inner_h, parameters=None d_inner_h_ref, h_inner_h_ref = self._setup_rho( d_inner_h, h_inner_h, parameters=parameters) if self.phase_marginalization: - d_inner_h_ref = np.abs(d_inner_h_ref) + d_inner_h_ref = abs(d_inner_h_ref) else: - d_inner_h_ref = np.real(d_inner_h_ref) + d_inner_h_ref = d_inner_h_ref.real return self._interp_dist_margd_loglikelihood( - d_inner_h_ref, h_inner_h_ref, grid=False) + d_inner_h_ref, h_inner_h_ref) def phase_marginalized_likelihood(self, d_inner_h, h_inner_h): d_inner_h = ln_i0(abs(d_inner_h)) @@ -933,9 +933,19 @@ def _setup_distance_marginalization(self, lookup_table=None): self._create_lookup_table() else: self._create_lookup_table() - self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline( - self._d_inner_h_ref_array, self._optimal_snr_squared_ref_array, - 
self._dist_margd_loglikelihood_array.T, fill_value=-np.inf) + if "jax" in array_module(self.interferometers.frequency_array).__name__: + from interpax import Interpolator2D + import jax.numpy as jnp + self._interp_dist_margd_loglikelihood = Interpolator2D( + jnp.asarray(self._d_inner_h_ref_array), + jnp.asarray(self._optimal_snr_squared_ref_array), + jnp.asarray(self._dist_margd_loglikelihood_array.T), + extrap=-jnp.inf, + ) + else: + self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline( + self._d_inner_h_ref_array, self._optimal_snr_squared_ref_array, + self._dist_margd_loglikelihood_array.T, fill_value=-np.inf) @property def cached_lookup_table_filename(self): @@ -1107,8 +1117,10 @@ def get_sky_frame_parameters(self, parameters=None): ======= dict: dictionary containing ra, dec, and geocent_time """ + from ..conversion import convert_orientation_quaternion, convert_cartesian parameters = _fallback_to_parameters(self, parameters) - convert_orientation_quaternion(parameters) + if "orientation_w" in parameters: + convert_orientation_quaternion(parameters) time = parameters.get(f'{self.time_reference}_time', None) if time is None and "geocent_time" in parameters: logger.warning( From 025e3d53feabfb31838a9825562db8dac30eb9ff Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 25 Jan 2025 03:29:50 -0800 Subject: [PATCH 014/110] DEV: some more prior dict array refactoring --- bilby/core/prior/dict.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 09480a099..7a9a8dcfd 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -55,6 +55,9 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None): else: self.conversion_function = self.default_conversion_function + def __hash__(self): + return hash(str(self)) + def evaluate_constraints(self, sample): out_sample = self.conversion_function(sample) try: @@ -539,9 
+542,11 @@ def prob(self, sample, **kwargs): float: Joint probability of all individual sample probabilities """ - prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs) + xp = array_namespace(*sample.values()) + prob = xp.prod(xp.asarray([self[key].prob(sample[key]) for key in sample]), **kwargs) - return self.check_prob(sample, prob) + return prob + # return self.check_prob(sample, prob) def check_prob(self, sample, prob): ratio = self.normalize_constraint_factor(tuple(sample.keys())) @@ -809,12 +814,14 @@ def prob(self, sample, **kwargs): """ self._prepare_evaluation(*zip(*sample.items())) - res = [ + xp = array_namespace(*sample.values()) + res = xp.asarray([ self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample - ] - prob = np.prod(res, **kwargs) - return self.check_prob(sample, prob) + ]) + prob = xp.prod(res, **kwargs) + return prob + # return self.check_prob(sample, prob) def ln_prob(self, sample, axis=None, normalized=True): """ @@ -835,13 +842,15 @@ def ln_prob(self, sample, axis=None, normalized=True): """ self._prepare_evaluation(*zip(*sample.items())) - res = [ + xp = array_namespace(*sample.values()) + res = xp.array([ self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample - ] - ln_prob = np.sum(res, axis=axis) - return self.check_ln_prob(sample, ln_prob, - normalized=normalized) + ]) + ln_prob = xp.sum(res, axis=axis) + return ln_prob + # return self.check_ln_prob(sample, ln_prob, + # normalized=normalized) def cdf(self, sample): self._prepare_evaluation(*zip(*sample.items())) @@ -881,7 +890,7 @@ def rescale(self, keys, theta): theta[index], **self.get_required_variables(key) ) self[key].least_recently_sampled = result[key] - return xp.concatenate([result[key] for key in keys], axis=None) + return xp.array([result[key] for key in keys]) def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From 47baf2c698307486c72f0eff03040a23db7a60e4 Mon 
Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 29 Jan 2025 13:12:59 -0800 Subject: [PATCH 015/110] fix jax logic for distance marginalization --- bilby/gw/likelihood/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index a63e5d1bb..3383110aa 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -933,7 +933,7 @@ def _setup_distance_marginalization(self, lookup_table=None): self._create_lookup_table() else: self._create_lookup_table() - if "jax" in array_module(self.interferometers.frequency_array).__name__: + if "jax" in array_module(self.interferometers[0].vertex).__name__: from interpax import Interpolator2D import jax.numpy as jnp self._interp_dist_margd_loglikelihood = Interpolator2D( From b501d83df95c6a8b9d2b6502dba276484c93f5f7 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 29 Jan 2025 13:13:40 -0800 Subject: [PATCH 016/110] improve efficiency of setting up multibanding --- bilby/gw/likelihood/multiband.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index 46d29c915..d7ffabfde 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -552,7 +552,7 @@ def _setup_quadratic_coefficients_linear_interp(self): linear-interpolation algorithm""" logger.info("Linear-interpolation algorithm is used for (h, h).") self.quadratic_coeffs = dict((ifo.name, np.array([])) for ifo in self.interferometers) - original_duration = self.interferometers.duration + original_duration = float(self.interferometers.duration) for b in range(self.number_of_bands): logger.info(f"Pre-computing quadratic coefficients for the {b}-th band") @@ -576,7 +576,7 @@ def _setup_quadratic_coefficients_linear_interp(self): start_idx_in_band + len(window_sequence) - 1, len(ifo.power_spectral_density_array) - 1 ) - _frequency_mask = 
ifo.frequency_mask[start_idx_in_band:end_idx_in_band + 1] + _frequency_mask = np.asarray(ifo.frequency_mask[start_idx_in_band:end_idx_in_band + 1]) window_over_psd = np.zeros(end_idx_in_band + 1 - start_idx_in_band) window_over_psd[_frequency_mask] = \ 1. / ifo.power_spectral_density_array[start_idx_in_band:end_idx_in_band + 1][_frequency_mask] From ca7e4f896164194edafaaf3ad6215f69d71ea13c Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 29 Jan 2025 13:14:26 -0800 Subject: [PATCH 017/110] make high-dimensional gaussians jax compatible --- bilby/core/likelihood.py | 50 +++++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index 0ba344ec5..8c43cad0d 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -8,6 +8,7 @@ from scipy.stats import multivariate_normal from .utils import infer_parameters_from_function, infer_args_from_function_except_n_args, logger +from ..compat.utils import array_module PARAMETERS_AS_STATE = os.environ.get("BILBY_ALLOW_PARAMETERS_AS_STATE", "TRUE") @@ -553,11 +554,18 @@ class AnalyticalMultidimensionalCovariantGaussian(Likelihood): """ def __init__(self, mean, cov): - self.cov = np.atleast_2d(cov) - self.mean = np.atleast_1d(mean) - self.sigma = np.sqrt(np.diag(self.cov)) - self.pdf = multivariate_normal(mean=self.mean, cov=self.cov) - super(AnalyticalMultidimensionalCovariantGaussian, self).__init__() + xp = array_module(cov) + self.cov = xp.atleast_2d(cov) + self.mean = xp.atleast_1d(mean) + self.sigma = xp.sqrt(np.diag(self.cov)) + if xp == np: + self.logpdf = multivariate_normal(mean=self.mean, cov=self.cov).logpdf + else: + from functools import partial + from jax.scipy.stats.multivariate_normal import logpdf + self.logpdf = partial(logpdf, mean=self.mean, cov=self.cov) + parameters = {"x{0}".format(i): 0 for i in range(self.dim)} + super(AnalyticalMultidimensionalCovariantGaussian, 
self).__init__(parameters=parameters) @property def dim(self): @@ -565,8 +573,9 @@ def dim(self): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) - x = np.array([parameters["x{0}".format(i)] for i in range(self.dim)]) - return self.pdf.logpdf(x) + xp = array_module(self.cov) + x = xp.array([parameters["x{0}".format(i)] for i in range(self.dim)]) + return self.logpdf(x) class AnalyticalMultidimensionalBimodalCovariantGaussian(Likelihood): @@ -584,13 +593,21 @@ class AnalyticalMultidimensionalBimodalCovariantGaussian(Likelihood): """ def __init__(self, mean_1, mean_2, cov): - self.cov = np.atleast_2d(cov) - self.sigma = np.sqrt(np.diag(self.cov)) - self.mean_1 = np.atleast_1d(mean_1) - self.mean_2 = np.atleast_1d(mean_2) - self.pdf_1 = multivariate_normal(mean=self.mean_1, cov=self.cov) - self.pdf_2 = multivariate_normal(mean=self.mean_2, cov=self.cov) - super(AnalyticalMultidimensionalBimodalCovariantGaussian, self).__init__() + xp = array_module(cov) + self.cov = xp.atleast_2d(cov) + self.sigma = xp.sqrt(np.diag(self.cov)) + self.mean_1 = xp.atleast_1d(mean_1) + self.mean_2 = xp.atleast_1d(mean_2) + if xp == np: + self.logpdf_1 = multivariate_normal(mean=self.mean_1, cov=self.cov).logpdf + self.logpdf_2 = multivariate_normal(mean=self.mean_2, cov=self.cov).logpdf + else: + from functools import partial + from jax.scipy.stats.multivariate_normal import logpdf + self.logpdf_1 = partial(logpdf, mean=self.mean_1, cov=self.cov) + self.logpdf_2 = partial(logpdf, mean=self.mean_2, cov=self.cov) + parameters = {"x{0}".format(i): 0 for i in range(self.dim)} + super(AnalyticalMultidimensionalBimodalCovariantGaussian, self).__init__(parameters=parameters) @property def dim(self): @@ -598,8 +615,9 @@ def dim(self): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) - x = np.array([parameters["x{0}".format(i)] for i in range(self.dim)]) - return -np.log(2) + 
np.logaddexp(self.pdf_1.logpdf(x), self.pdf_2.logpdf(x)) + xp = array_module(self.cov) + x = xp.array([self.parameters["x{0}".format(i)] for i in range(self.dim)]) + return -xp.log(2) + xp.logaddexp(self.logpdf_1(x), self.logpdf_2(x)) class JointLikelihood(Likelihood): From a8a9b98926a98ba2c0eab204b81de1449ef31485 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 30 Jan 2025 12:53:42 -0800 Subject: [PATCH 018/110] make cubic spline calibration work with jax backend --- bilby/gw/detector/calibration.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bilby/gw/detector/calibration.py b/bilby/gw/detector/calibration.py index 729b9e332..229bdba25 100644 --- a/bilby/gw/detector/calibration.py +++ b/bilby/gw/detector/calibration.py @@ -46,6 +46,7 @@ import pandas as pd from scipy.interpolate import interp1d +from ...compat.utils import array_module from ...core.utils.log import logger from ...core.prior.dict import PriorDict from ..prior import CalibrationPriorDict @@ -330,9 +331,11 @@ def __repr__(self): def _evaluate_spline(self, kind, a, b, c, d, previous_nodes): """Evaluate Eq. 
(1) in https://dcc.ligo.org/LIGO-T2300140""" - parameters = np.array([self.params[f"{kind}_{ii}"] for ii in range(self.n_points)]) + xp = array_module(self.params[f"{kind}_0"]) + parameters = xp.array([self.params[f"{kind}_{ii}"] for ii in range(self.n_points)]) next_nodes = previous_nodes + 1 - spline_coefficients = self.nodes_to_spline_coefficients.dot(parameters) + nodes = xp.array(self.nodes_to_spline_coefficients) + spline_coefficients = nodes.dot(parameters) return ( a * parameters[previous_nodes] + b * parameters[next_nodes] From 2117df4a253d0cbf14a8931f5088a6be7cd7e98c Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 4 Feb 2025 11:54:08 -0600 Subject: [PATCH 019/110] BUG: fix linspace calls --- bilby/core/utils/series.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/bilby/core/utils/series.py b/bilby/core/utils/series.py index 8affa61be..c60362ab3 100644 --- a/bilby/core/utils/series.py +++ b/bilby/core/utils/series.py @@ -101,9 +101,11 @@ def create_time_series(sampling_frequency, duration, starting_time=0.): xp = array_module(sampling_frequency) _check_legal_sampling_frequency_and_duration(sampling_frequency, duration) number_of_samples = int(duration * sampling_frequency) - return xp.linspace(start=starting_time, - stop=duration + starting_time - 1 / sampling_frequency, - num=number_of_samples) + return xp.linspace( + starting_time, + duration + starting_time - 1 / sampling_frequency, + num=number_of_samples, + ) def create_frequency_series(sampling_frequency, duration): @@ -124,9 +126,7 @@ def create_frequency_series(sampling_frequency, duration): number_of_samples = int(xp.round(duration * sampling_frequency)) number_of_frequencies = int(xp.round(number_of_samples / 2) + 1) - return xp.linspace(start=0, - stop=sampling_frequency / 2, - num=number_of_frequencies) + return xp.linspace(0, sampling_frequency / 2, num=number_of_frequencies) def _check_legal_sampling_frequency_and_duration(sampling_frequency, 
duration): @@ -209,10 +209,11 @@ def nfft(time_domain_strain, sampling_frequency): strain / Hz, and the associated frequency_array. """ - frequency_domain_strain = np.fft.rfft(time_domain_strain) + xp = array_module(time_domain_strain) + frequency_domain_strain = xp.fft.rfft(time_domain_strain) frequency_domain_strain /= sampling_frequency - frequency_array = np.linspace( + frequency_array = xp.linspace( 0, sampling_frequency / 2, len(frequency_domain_strain)) return frequency_domain_strain, frequency_array @@ -234,7 +235,8 @@ def infft(frequency_domain_strain, sampling_frequency): time_domain_strain: array_like An array of the time domain strain """ - time_domain_strain_norm = np.fft.irfft(frequency_domain_strain) + xp = array_module(frequency_domain_strain) + time_domain_strain_norm = xp.fft.irfft(frequency_domain_strain) time_domain_strain = time_domain_strain_norm * sampling_frequency return time_domain_strain From 3e46a9b785633b8bd800cdd3758d6b0905fb79ce Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 4 Feb 2025 14:56:33 -0600 Subject: [PATCH 020/110] ENH: fix bottleneck in relative binning for JAX --- bilby/gw/likelihood/relative.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/gw/likelihood/relative.py b/bilby/gw/likelihood/relative.py index 48a4cd585..d40015219 100644 --- a/bilby/gw/likelihood/relative.py +++ b/bilby/gw/likelihood/relative.py @@ -328,7 +328,7 @@ def compute_summary_data(self): masked_bin_inds[-1] += 1 masked_strain = interferometer.frequency_domain_strain[mask] - masked_h0 = self.per_detector_fiducial_waveforms[interferometer.name][mask] + masked_h0 = np.asarray(self.per_detector_fiducial_waveforms[interferometer.name][mask]) masked_psd = interferometer.power_spectral_density_array[mask] duration = interferometer.duration a0, b0, a1, b1 = np.zeros((4, self.number_of_bins), dtype=complex) From 1d891cc503623fb4a7c72138ba9fdcb8755ec35c Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 4 Feb 2025 14:58:15 
-0600 Subject: [PATCH 021/110] ENH: make interpolated prior backend friendly --- bilby/core/prior/interpolated.py | 3 ++- bilby/core/utils/calculus.py | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index d47f14209..57e04738d 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -2,7 +2,8 @@ from scipy.integrate import trapezoid from .base import Prior -from ..utils import logger, WrappedInterp1d as interp1d +from ..utils import logger +from ..utils.calculus import interp1d from ...compat.utils import xp_wrap diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index e10ce6111..137dd894c 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -1,10 +1,11 @@ import math import numpy as np -from scipy.interpolate import interp1d, RectBivariateSpline +from scipy.interpolate import RectBivariateSpline, interp1d as _interp1d from scipy.special import logsumexp from .log import logger +from ...compat.utils import array_module def derivatives( @@ -189,6 +190,20 @@ def logtrapzexp(lnf, dx): return C + logsumexp([logsumexp(lnfdx1), logsumexp(lnfdx2)]) +class interp1d(_interp1d): + + def __call__(self, x): + xp = array_module(x) + if "jax" in xp.__name__: + if isinstance(self.fill_value, tuple): + left, right = self.fill_value + else: + left = right = self.fill_value + return xp.interp(x , xp.asarray(self.x), xp.asarray(self.y), left=left, right=right) + else: + return super().__call__(x) + + class BoundedRectBivariateSpline(RectBivariateSpline): def __init__(self, x, y, z, bbox=[None] * 4, kx=3, ky=3, s=0, fill_value=None): From a48544bc3b55923e42de6da156d9135c354e380d Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 5 Feb 2025 12:14:25 -0600 Subject: [PATCH 022/110] REFACTOR: refactor backend-specific interpolation code --- bilby/core/utils/calculus.py | 51 
+++++++++++++++++++++++++++++------- bilby/gw/likelihood/base.py | 19 +++++--------- 2 files changed, 47 insertions(+), 23 deletions(-) diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index 137dd894c..bb89f48ee 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -192,16 +192,27 @@ def logtrapzexp(lnf, dx): class interp1d(_interp1d): - def __call__(self, x): - xp = array_module(x) - if "jax" in xp.__name__: - if isinstance(self.fill_value, tuple): - left, right = self.fill_value - else: - left = right = self.fill_value - return xp.interp(x , xp.asarray(self.x), xp.asarray(self.y), left=left, right=right) - else: - return super().__call__(x) + def __call__(self, x): + from array_api_compat import is_numpy_namespace + + xp = array_module(x) + if is_numpy_namespace(xp): + return super().__call__(x) + else: + return self._call_alt(x, xp=xp) + + def _call_alt(self, x, *, xp=np): + if isinstance(self.fill_value, tuple): + left, right = self.fill_value + else: + left = right = self.fill_value + return xp.interp( + x, + xp.asarray(self.x), + xp.asarray(self.y), + left=left, + right=right, + ) class BoundedRectBivariateSpline(RectBivariateSpline): @@ -217,9 +228,16 @@ def __init__(self, x, y, z, bbox=[None] * 4, kx=3, ky=3, s=0, fill_value=None): if self.y_max is None: self.y_max = max(y) self.fill_value = fill_value + self.x = x + self.y = y + self.z = z super().__init__(x=x, y=y, z=z, bbox=bbox, kx=kx, ky=ky, s=s) def __call__(self, x, y, dx=0, dy=0, grid=False): + from array_api_compat import is_jax_namespace + xp = array_module(x) + if is_jax_namespace(xp): + return self._call_jax(x, y) result = super().__call__(x=x, y=y, dx=dx, dy=dy, grid=grid) out_of_bounds_x = (x < self.x_min) | (x > self.x_max) out_of_bounds_y = (y < self.y_min) | (y > self.y_max) @@ -232,6 +250,19 @@ def __call__(self, x, y, dx=0, dy=0, grid=False): return result.item() else: return result + + def _call_jax(self, x, y): + import jax.numpy as jnp 
+ from interpax import interp2d + + return interp2d( + x, + y, + jnp.asarray(self.x), + jnp.asarray(self.y), + jnp.asarray(self.z), + extrap=self.fill_value, + ) class WrappedInterp1d(interp1d): diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index 3383110aa..365c01afa 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -933,19 +933,12 @@ def _setup_distance_marginalization(self, lookup_table=None): self._create_lookup_table() else: self._create_lookup_table() - if "jax" in array_module(self.interferometers[0].vertex).__name__: - from interpax import Interpolator2D - import jax.numpy as jnp - self._interp_dist_margd_loglikelihood = Interpolator2D( - jnp.asarray(self._d_inner_h_ref_array), - jnp.asarray(self._optimal_snr_squared_ref_array), - jnp.asarray(self._dist_margd_loglikelihood_array.T), - extrap=-jnp.inf, - ) - else: - self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline( - self._d_inner_h_ref_array, self._optimal_snr_squared_ref_array, - self._dist_margd_loglikelihood_array.T, fill_value=-np.inf) + self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline( + self._d_inner_h_ref_array, + self._optimal_snr_squared_ref_array, + self._dist_margd_loglikelihood_array.T, + fill_value=-np.inf, + ) @property def cached_lookup_table_filename(self): From 15edfba2cdd8f84490b0e336ba1f0c14c37456b0 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 5 Feb 2025 12:35:04 -0600 Subject: [PATCH 023/110] ENH: make sine gaussian model backend independent --- bilby/gw/source.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/bilby/gw/source.py b/bilby/gw/source.py index 951346760..b38d56436 100644 --- a/bilby/gw/source.py +++ b/bilby/gw/source.py @@ -1,5 +1,6 @@ import numpy as np +from ..compat.utils import array_module from ..core import utils from ..core.utils import logger from .conversion import bilby_to_lalsimulation_spins @@ -1187,20 +1188,22 @@ def 
sinegaussian(frequency_array, hrss, Q, frequency, **kwargs): dict: Dictionary containing the plus and cross components of the strain. """ - tau = Q / (np.sqrt(2.0) * np.pi * frequency) - temp = Q / (4.0 * np.sqrt(np.pi) * frequency) + xp = array_module(frequency_array) + tau = Q / (2.0**0.5 * np.pi * frequency) + temp = Q / (4.0 * np.pi**0.5 * frequency) fm = frequency_array - frequency fp = frequency_array + frequency - h_plus = ((hrss / np.sqrt(temp * (1 + np.exp(-Q**2)))) * - ((np.sqrt(np.pi) * tau) / 2.0) * - (np.exp(-fm**2 * np.pi**2 * tau**2) + - np.exp(-fp**2 * np.pi**2 * tau**2))) + negative_term = xp.exp(-fm**2 * np.pi**2 * tau**2) + positive_term = xp.exp(-fp**2 * np.pi**2 * tau**2) - h_cross = (-1j * (hrss / np.sqrt(temp * (1 - np.exp(-Q**2)))) * - ((np.sqrt(np.pi) * tau) / 2.0) * - (np.exp(-fm**2 * np.pi**2 * tau**2) - - np.exp(-fp**2 * np.pi**2 * tau**2))) + h_plus = hrss * np.pi**0.5 * tau / 2 * ( + negative_term + positive_term + ) / (temp * (1 + xp.exp(-Q**2)))**0.5 + + h_cross = -1j * hrss * np.pi**0.5 * tau / 2 * ( + negative_term - positive_term + ) / (temp * (1 - np.exp(-Q**2)))**0.5 return {'plus': h_plus, 'cross': h_cross} From 5ce19b3be90eb6afa57e39a3ef354a90bc01b8df Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 5 Feb 2025 12:35:52 -0600 Subject: [PATCH 024/110] ENH: make roq likelihood backend independent --- bilby/gw/likelihood/roq.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 0f5a4c003..3e23137f6 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -2,6 +2,7 @@ import numpy as np from .base import GravitationalWaveTransient +from ...compat.utils import array_module from ...core.utils import ( logger, create_frequency_series, speed_of_light, radius_of_earth ) @@ -458,8 +459,8 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr quadratic_indices = 
self.waveform_generator.waveform_arguments['quadratic_indices'] size_linear = len(linear_indices) size_quadratic = len(quadratic_indices) - h_linear = np.zeros(size_linear, dtype=complex) - h_quadratic = np.zeros(size_quadratic, dtype=complex) + h_linear = 0j + h_quadratic = 0j for mode in waveform_polarizations['linear']: response = interferometer.antenna_response( parameters['ra'], parameters['dec'], @@ -470,13 +471,14 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr h_linear += waveform_polarizations['linear'][mode] * response h_quadratic += waveform_polarizations['quadratic'][mode] * response + xp = array_module(h_linear) calib_factor = interferometer.calibration_model.get_calibration_factor( frequency_nodes, prefix='recalib_{}_'.format(interferometer.name), **parameters) h_linear *= calib_factor[linear_indices] h_quadratic *= calib_factor[quadratic_indices] - optimal_snr_squared = np.vdot( - np.abs(h_quadratic)**2, + optimal_snr_squared = xp.vdot( + xp.abs(h_quadratic)**2, self.weights[interferometer.name + '_quadratic'][self.basis_number_quadratic] ) @@ -487,20 +489,18 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr indices, in_bounds = self._closest_time_indices( ifo_time, self.weights['time_samples']) - if not in_bounds: - logger.debug("SNR calculation error: requested time at edge of ROQ time samples") - d_inner_h = -np.inf - complex_matched_filter_snr = -np.inf - else: - d_inner_h_tc_array = np.einsum( - 'i,ji->j', np.conjugate(h_linear), - self.weights[interferometer.name + '_linear'][self.basis_number_linear][indices]) + d_inner_h_tc_array = xp.einsum( + 'i,ji->j', xp.conjugate(h_linear), + self.weights[interferometer.name + '_linear'][self.basis_number_linear][indices]) + + d_inner_h = self._interp_five_samples( + self.weights['time_samples'][indices], d_inner_h_tc_array, ifo_time) - d_inner_h = self._interp_five_samples( - self.weights['time_samples'][indices], d_inner_h_tc_array, 
ifo_time) + with np.errstate(invalid="ignore"): + complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) - with np.errstate(invalid="ignore"): - complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) + d_inner_h += xp.log(in_bounds) + complex_matched_filter_snr += xp.log(in_bounds) if return_array and self.time_marginalization: ifo_times = self._times - interferometer.strain_data.start_time @@ -537,10 +537,11 @@ def _closest_time_indices(time, samples): in_bounds: bool Whether the indices are for valid times """ - closest = int((time - samples[0]) / (samples[1] - samples[0])) + xp = array_module(time) + closest = xp.floor((time - samples[0]) / (samples[1] - samples[0])) indices = [closest + ii for ii in [-2, -1, 0, 1, 2]] in_bounds = (indices[0] >= 0) & (indices[-1] < samples.size) - return indices, in_bounds + return xp.asarray(indices).astype(int), in_bounds @staticmethod def _interp_five_samples(time_samples, values, time): From 8c7e992e3804288ae4bfca22583165095e7ffdc4 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 5 Feb 2025 12:39:23 -0600 Subject: [PATCH 025/110] BUG: fix roq slicing --- bilby/gw/likelihood/roq.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 3e23137f6..6097557f7 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -490,11 +490,16 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr indices, in_bounds = self._closest_time_indices( ifo_time, self.weights['time_samples']) d_inner_h_tc_array = xp.einsum( - 'i,ji->j', xp.conjugate(h_linear), - self.weights[interferometer.name + '_linear'][self.basis_number_linear][indices]) + 'i,ji->j', + xp.conjugate(h_linear), + xp.asarray( + self.weights[interferometer.name + '_linear'][self.basis_number_linear] + )[indices], + ) d_inner_h = self._interp_five_samples( - self.weights['time_samples'][indices], d_inner_h_tc_array, 
ifo_time) + xp.asarray(self.weights['time_samples'])[indices], d_inner_h_tc_array, ifo_time + ) with np.errstate(invalid="ignore"): complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) From cf270d9bbed655f70a0b3d9ae08923ab2add74e2 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Jun 2025 10:15:55 -0500 Subject: [PATCH 026/110] FEAT: make condition chi evaluable --- bilby/gw/prior.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index e262eaaf3..04c1d0db8 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -7,6 +7,7 @@ from scipy.special import hyp2f1 from scipy.stats import norm +from ..compat.utils import xp_wrap from ..core.prior import ( PriorDict, Uniform, Prior, DeltaFunction, Gaussian, Interped, Constraint, conditional_prior_factory, PowerLaw, ConditionalLogUniform, @@ -600,19 +601,23 @@ def __init__(self, minimum, maximum, name, latex_label=None, unit=None, boundary self.__class__.__name__ = "ConditionalChiInPlane" self.__class__.__qualname__ = "ConditionalChiInPlane" - def prob(self, val, **required_variables): - self.update_conditions(**required_variables) + @xp_wrap + def prob(self, val, *, xp=np, **required_variables): + parameters = self.condition_func(self.reference_params.copy(), **required_variables) chi_aligned = abs(required_variables[self._required_variables[0]]) + minimum = parameters.get("minimum", self.minimum) + maximum = parameters.get("maximum", self.maximum) return ( - (val >= self.minimum) * (val <= self.maximum) + (val >= minimum) * (val <= maximum) * val / (chi_aligned ** 2 + val ** 2) - / np.log(self._reference_maximum / chi_aligned) + / xp.log(self._reference_maximum / chi_aligned) ) - def ln_prob(self, val, **required_variables): + @xp_wrap + def ln_prob(self, val, *, xp=np, **required_variables): with np.errstate(divide="ignore"): - return np.log(self.prob(val, **required_variables)) + return xp.log(self.prob(val, 
**required_variables)) def cdf(self, val, **required_variables): r""" @@ -664,9 +669,9 @@ def rescale(self, val, **required_variables): def _condition_function(self, reference_params, **kwargs): with np.errstate(invalid="ignore"): - maximum = np.sqrt( + maximum = ( self._reference_maximum ** 2 - kwargs[self._required_variables[0]] ** 2 - ) + )**0.5 return dict(minimum=0, maximum=maximum) def __repr__(self): From f40e8458af3e3748845d01e2a25800ce150d3ca1 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 12 Jun 2025 12:22:58 -0500 Subject: [PATCH 027/110] MAINT: make whitening work for non-numpy --- bilby/gw/detector/interferometer.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 437cf6abf..bf8a5a791 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -485,7 +485,7 @@ def inject_signal_from_waveform_polarizations(self, parameters, injection_polari self.strain_data.frequency_domain_strain += signal_ifo self.meta_data['optimal_SNR'] = ( - np.sqrt(self.optimal_snr_squared(signal=signal_ifo)).real) + self.optimal_snr_squared(signal=signal_ifo)).real ** 0.5 self.meta_data['matched_filter_SNR'] = ( self.matched_filter_snr(signal=signal_ifo)) self.meta_data['parameters'] = parameters @@ -671,7 +671,7 @@ def whiten_frequency_series(self, frequency_series : np.array) -> np.array: frequency_series : np.array The frequency series, whitened by the ASD """ - return frequency_series / (self.amplitude_spectral_density_array * np.sqrt(self.duration / 4)) + return frequency_series / (self.amplitude_spectral_density_array * (self.duration / 4)**0.5) def get_whitened_time_series_from_whitened_frequency_series( self, @@ -702,14 +702,13 @@ def get_whitened_time_series_from_whitened_frequency_series( w = \\sqrt{N W} = \\sqrt{\\sum_{k=0}^N \\Theta(f_{max} - f_k)\\Theta(f_k - f_{min})} """ - frequency_window_factor = ( - 
np.sum(self.frequency_mask) - / len(self.frequency_mask) - ) + xp = array_module(whitened_frequency_series) + + frequency_window_factor = self.frequency_mask.mean() whitened_time_series = ( - np.fft.irfft(whitened_frequency_series) - * np.sqrt(np.sum(self.frequency_mask)) / frequency_window_factor + xp.fft.irfft(whitened_frequency_series) + * self.frequency_mask.sum()**0.5 / frequency_window_factor ) return whitened_time_series From 2f17eeeddbcef378aa60962f00a49cd105f3ad37 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 20 Aug 2025 06:35:42 -0700 Subject: [PATCH 028/110] EXAMPLE: update jax example --- .../injection_examples/jax_fast_tutorial.py | 274 ++++++------------ 1 file changed, 86 insertions(+), 188 deletions(-) diff --git a/examples/gw_examples/injection_examples/jax_fast_tutorial.py b/examples/gw_examples/injection_examples/jax_fast_tutorial.py index b1b586ecf..6bfd87ea8 100644 --- a/examples/gw_examples/injection_examples/jax_fast_tutorial.py +++ b/examples/gw_examples/injection_examples/jax_fast_tutorial.py @@ -17,6 +17,7 @@ import bilby import bilby.gw.jaxstuff +import numpy as np import jax import jax.numpy as jnp from jax import random @@ -28,10 +29,58 @@ bilby.core.utils.setup_logger() # log_level="WARNING") -def main(use_jax, model): +def setup_prior(): + # Set up a PriorDict, which inherits from dict. + # By default we will sample all terms in the signal models. However, this will + # take a long time for the calculation, so for this example we will set almost + # all of the priors to be equall to their injected values. This implies the + # prior is a delta function at the true, injected value. In reality, the + # sampler implementation is smart enough to not sample any parameter that has + # a delta-function prior. + # The above list does *not* include mass_1, mass_2, theta_jn and luminosity + # distance, which means those are the parameters that will be included in the + # sampler. If we do nothing, then the default priors get used. 
+ priors = bilby.gw.prior.BBHPriorDict() + del priors["mass_1"], priors["mass_2"] + priors["geocent_time"] = bilby.core.prior.Uniform(1126249642, 1126269642) + priors["luminosity_distance"].minimum = 1 + priors["luminosity_distance"].maximum = 500 + priors["chirp_mass"].minimum = 2.35 + priors["chirp_mass"].maximum = 2.45 + # priors["luminosity_distance"] = bilby.core.prior.PowerLaw(2.0, 10.0, 500.0) + # priors["sky_x"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["sky_y"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["sky_z"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["delta_phase"] = priors.pop("phase") + # del priors["tilt_1"], priors["tilt_2"], priors["phi_12"], priors["phi_jl"] + # priors["spin_1_x"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["spin_1_y"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["spin_1_z"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["spin_2_x"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["spin_2_y"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["spin_2_z"] = bilby.core.prior.Normal(mu=0, sigma=1) + # # del priors["a_1"], priors["a_2"] + # # priors["chi_1"] = bilby.core.prior.Uniform(-0.05, 0.05) + # # priors["chi_2"] = bilby.core.prior.Uniform(-0.05, 0.05) + # del priors["theta_jn"], priors["psi"], priors["delta_phase"] + # priors["orientation_w"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["orientation_x"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["orientation_y"] = bilby.core.prior.Normal(mu=0, sigma=1) + # priors["orientation_z"] = bilby.core.prior.Normal(mu=0, sigma=1) + return priors + + +def original_to_sampling_priors(priors, truth): + del priors["ra"], priors["dec"] + priors["zenith"] = bilby.core.prior.Cosine() + priors["azimuth"] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi) + priors["L1_time"] = bilby.core.prior.Uniform(truth["geocent_time"] - 0.1, truth["geocent_time"] + 0.1) + + +def main(use_jax, model, idx): # Set the duration and 
sampling frequency of the data segment that we're # going to inject the signal into - duration = 4.0 + duration = 64.0 sampling_frequency = 2048.0 minimum_frequency = 20.0 if use_jax: @@ -40,35 +89,17 @@ def main(use_jax, model): minimum_frequency = jax.numpy.array(minimum_frequency) # Specify the output directory and the name of the simulation. - outdir = "outdir" - label = f"{model}_{'jax' if use_jax else 'numpy'}" + outdir = "pp-test-2" + label = f"{model}_{'jax' if use_jax else 'numpy'}_{idx}" # Set up a random seed for result reproducibility. This is optional! - bilby.core.utils.random.seed(88170235) + bilby.core.utils.random.seed(88170235 + idx * 1000) - # We are going to inject a binary black hole waveform. We first establish a - # dictionary of parameters that includes all of the different waveform - # parameters, including masses of the two black holes (mass_1, mass_2), - # spins of both black holes (a, tilt, phi), etc. - injection_parameters = dict( - mass_1=36.0, - mass_2=29.0, - a_1=0.4, - a_2=0.3, - tilt_1=0.5, - tilt_2=1.0, - phi_12=1.7, - phi_jl=0.3, - luminosity_distance=2000.0, - theta_jn=0.4, - psi=2.659, - phase=1.3, - geocent_time=1126259642.413, - ra=1.375, - dec=-1.2108, - ) + priors = setup_prior() + injection_parameters = priors.sample() if model == "relbin": injection_parameters["fiducial"] = 1 + original_to_sampling_priors(priors, injection_parameters) # Fixed arguments passed into the source model waveform_arguments = dict( @@ -89,13 +120,14 @@ def main(use_jax, model): fdsm = bilby.gw.source.lal_binary_black_hole_relative_binning case _: fdsm = bilby.gw.source.lal_binary_black_hole + # fdsm = bilby.gw.source.sinegaussian # Create the waveform_generator using a LAL BinaryBlackHole source function waveform_generator = bilby.gw.WaveformGenerator( duration=duration, sampling_frequency=sampling_frequency, frequency_domain_source_model=fdsm, - parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters, + # 
parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters, waveform_arguments=waveform_arguments, use_cache=not use_jax, ) @@ -107,10 +139,11 @@ def main(use_jax, model): ifos.set_strain_data_from_power_spectral_densities( sampling_frequency=sampling_frequency, duration=duration, - start_time=injection_parameters["geocent_time"] - 2, + start_time=injection_parameters["geocent_time"] - duration + 2, ) ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters + waveform_generator=waveform_generator, parameters=injection_parameters, + raise_error=False, ) if use_jax: ifos.set_array_backend(jax.numpy) @@ -124,58 +157,6 @@ def main(use_jax, model): ) del waveform_generator.waveform_arguments["minimum_frequency"] - # Set up a PriorDict, which inherits from dict. - # By default we will sample all terms in the signal models. However, this will - # take a long time for the calculation, so for this example we will set almost - # all of the priors to be equall to their injected values. This implies the - # prior is a delta function at the true, injected value. In reality, the - # sampler implementation is smart enough to not sample any parameter that has - # a delta-function prior. - # The above list does *not* include mass_1, mass_2, theta_jn and luminosity - # distance, which means those are the parameters that will be included in the - # sampler. If we do nothing, then the default priors get used. 
- priors = bilby.gw.prior.BBHPriorDict() - for key in [ - # "a_1", - # "a_2", - # "tilt_1", - # "tilt_2", - # "phi_12", - # "phi_jl", - # "psi", - # "ra", - # "dec", - # "geocent_time", - ]: - priors[key] = injection_parameters[key] - del priors["mass_1"], priors["mass_2"] - priors["L1_time"] = bilby.core.prior.Uniform(1126259642.41, 1126259642.45) - del priors["ra"], priors["dec"] - # priors["zenith"] = bilby.core.prior.Cosine() - # priors["azimuth"] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi) - priors["sky_x"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["sky_y"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["sky_z"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["delta_phase"] = priors.pop("phase") - priors["chirp_mass"].minimum = 20 - priors["chirp_mass"].maximum = 35 - del priors["tilt_1"], priors["tilt_2"], priors["phi_12"], priors["phi_jl"] - priors["spin_1_x"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["spin_1_y"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["spin_1_z"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["spin_2_x"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["spin_2_y"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["spin_2_z"] = bilby.core.prior.Normal(mu=0, sigma=1) - del priors["theta_jn"], priors["psi"], priors["delta_phase"] - priors["orientation_w"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["orientation_x"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["orientation_y"] = bilby.core.prior.Normal(mu=0, sigma=1) - priors["orientation_z"] = bilby.core.prior.Normal(mu=0, sigma=1) - - # Perform a check that the prior does not extend to a parameter space longer than the data - if not use_jax: - priors.validate_prior(duration, minimum_frequency) - # Initialise the likelihood by passing in the interferometer data (ifos) and # the waveform generator match model: @@ -191,130 +172,47 @@ def main(use_jax, model): interferometers=ifos, waveform_generator=waveform_generator, priors=priors, - # 
phase_marginalization=True, + phase_marginalization=True, + distance_marginalization=True, reference_frame=ifos, time_reference="L1", + # epsilon=0.1, + # update_fiducial_parameters=True, ) - if use_jax: - - def sample(): - parameters = priors.sample() - parameters = {key: jax.numpy.array(val) for key, val in parameters.items()} - return parameters - - # burn a few likelihood calls to check that we don't get - # repeated compilation - likelihood.parameters.update(sample()) - likelihood.log_likelihood_ratio() - likelihood.log_likelihood() - likelihood.noise_log_likelihood() - - with jax.log_compiles(): - jit_likelihood = bilby.gw.jaxstuff.JittedLikelihood( - likelihood, - cast_to_float=False, - ) - jit_likelihood.parameters.update(sample()) - jit_likelihood.log_likelihood_ratio() - jit_likelihood.log_likelihood() - jit_likelihood.noise_log_likelihood() - jit_likelihood.parameters.update(sample()) - jit_likelihood.log_likelihood_ratio() - jit_likelihood.log_likelihood() - jit_likelihood.noise_log_likelihood() - sample_likelihood = jit_likelihood - else: - sample_likelihood = likelihood - - def likelihood_func(parameters): - return sample_likelihood.likelihood_func(parameters, **sample_likelihood.kwargs) - - # import IPython; IPython.embed() - # raise SystemExit() - # use the log_compiles context so we can make sure there aren't recompilations # inside the sampling loop - with jax.log_compiles(): + if True: + # with jax.log_compiles(): result = bilby.run_sampler( - likelihood=sample_likelihood, + likelihood=likelihood, priors=priors, - sampler="dynesty", - # sampler="numpyro", - sampler_name="ESS", - # sampler_name="NUTS", - num_warmup=500, - num_samples=500, - num_chains=100, - thinning=5, - # moves={ - # AIES.DEMove(): 0.35, - # ModeHopping(): 0.3, - # AIES.StretchMove(): 0.35, - # }, - moves={ - ESS.DifferentialMove(): 0.25, - ESS.KDEMove(): 0.25, - ESS.GaussianMove(): 0.5, - }, - chain_method="vectorized", - npoints=500, - # sample="acceptance-walk", - 
sample="act-walk", - naccept=10, + sampler="jaxted" if use_jax else "dynesty", + nlive=1000, + sample="acceptance-walk", + method="nest", + nsteps=100, + naccept=30, injection_parameters=injection_parameters, outdir=outdir, label=label, - npool=4, + npool=None if use_jax else 16, + # save="hdf5", + save=False, + rseed=np.random.randint(0, 100000), ) - # print(result) - # print(f"Sampling time: {result.sampling_time:.1f}s\n") # Make a corner plot. - result.plot_corner() - raise SystemExit() + # result.plot_corner() + import IPython; IPython.embed() return result.sampling_time -def ModeHopping(): - """ - A proposal using differential evolution. - - This `Differential evolution proposal - `_ is - implemented following `Nelson et al. (2013) - `_. - - :param sigma: (optional) - The standard deviation of the Gaussian used to stretch the proposal vector. - Defaults to `1.0.e-5`. - :param g0 (optional): - The mean stretch factor for the proposal vector. By default, - it is `2.38 / sqrt(2*ndim)` as recommended by the two references. 
- """ - - def make_de_move(n_chains): - PAIRS = get_nondiagonal_indices(n_chains // 2) - - def de_move(rng_key, active, inactive): - n_active_chains, _ = inactive.shape - - selected_pairs = random.choice(rng_key, PAIRS, shape=(n_active_chains,)) - - # Compute diff vectors - diffs = jnp.diff(inactive[selected_pairs], axis=1).squeeze(axis=1) - - proposal = active + diffs - - return proposal, jnp.zeros(n_active_chains) - - return de_move - - return make_de_move - - if __name__ == "__main__": times = dict() - for arg in product([True, False][1:], ["relbin", "mb", "regular"][1:2]): - times[arg] = main(*arg) + # for arg in product([True, False][:], ["relbin", "mb", "regular"][2:3]): + # times[arg] = main(*arg) + with jax.log_compiles(): + for idx in np.arange(100): + times[idx] = main(True, "mb", idx) print(times) From ee259594a92ac9277e80e202fee9936c533c77ac Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 20 Aug 2025 09:48:20 -0700 Subject: [PATCH 029/110] BUG: fix interpax interpolation method --- bilby/core/utils/calculus.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index bb89f48ee..f20973f4e 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -262,6 +262,7 @@ def _call_jax(self, x, y): jnp.asarray(self.y), jnp.asarray(self.z), extrap=self.fill_value, + method="cubic2", ) From 21d2306c56b0d50ccaca195114c7eea454b405f0 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 2 Oct 2025 16:06:10 +0000 Subject: [PATCH 030/110] REFACTOR: update variable backend for new parameter method --- bilby/compat/jax.py | 8 ++------ bilby/core/likelihood.py | 2 +- bilby/gw/detector/interferometer.py | 1 + bilby/gw/likelihood/base.py | 1 + requirements.txt | 2 ++ 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/bilby/compat/jax.py b/bilby/compat/jax.py index 8d297487b..7c35aa6f3 100644 --- a/bilby/compat/jax.py +++ b/bilby/compat/jax.py @@ -2,7 +2,7 @@ import jax import 
jax.numpy as jnp -from ..core.likelihood import Likelihood +from ..core.likelihood import Likelihood, _safe_likelihood_call def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True): @@ -20,11 +20,7 @@ def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True): Default is :code:`True`. """ parameters = {k: jnp.array(v) for k, v in parameters.items()} - likelihood.parameters.update(parameters) - if use_ratio: - return likelihood.log_likelihood_ratio() - else: - return likelihood.log_likelihood() + return _safe_likelihood_call(likelihood, parameters, use_ratio) class JittedLikelihood(Likelihood): diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index 8c43cad0d..0f0e50b8b 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -167,7 +167,7 @@ class ZeroLikelihood(Likelihood): def __init__(self, likelihood): super(ZeroLikelihood, self).__init__() - self.parameters = likelihood.parameters + self.parameters = dict() self._parent = likelihood def log_likelihood(self, parameters=None): diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index bf8a5a791..9a8584ddf 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -4,6 +4,7 @@ from ...core import utils from ...core.utils import docstring, logger, PropertyAccessor, safe_file_dump +from ...core.utils.env import string_to_boolean from ...compat.utils import array_module from .. 
import utils as gwutils from ..geometry import ( diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index 365c01afa..53caf0fcb 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -411,6 +411,7 @@ def noise_log_likelihood(self): self._noise_log_likelihood_value = self._calculate_noise_log_likelihood() return self._noise_log_likelihood_value + def log_likelihood_ratio(self, parameters=None): if parameters is not None: parameters = copy.deepcopy(parameters) else: diff --git a/requirements.txt b/requirements.txt index b045db212..2b3cc4405 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,5 @@ tqdm h5py attrs importlib-metadata>=3.6; python_version < '3.10' +plum-dispatch +array_api_compat From 3012c99ca26725f10c694cb5a816b775a3ad33ac Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 2 Oct 2025 19:38:44 +0000 Subject: [PATCH 031/110] some simplifications of array transparency --- bilby/compat/jax.py | 58 +++++++++-------------------- bilby/compat/utils.py | 2 +- bilby/core/likelihood.py | 40 ++++++++++++-------- bilby/core/prior/analytical.py | 48 ++++++++++++++---------- bilby/gw/detector/geometry.py | 10 +++++ bilby/gw/detector/interferometer.py | 18 +++------ bilby/gw/detector/networks.py | 4 ++ bilby/gw/likelihood/base.py | 12 +++--- bilby/gw/likelihood/basic.py | 18 +++++---- bilby/gw/sampler/proposal.py | 2 +- bilby/gw/source.py | 7 ++-- 11 files changed, 112 insertions(+), 107 deletions(-) diff --git a/bilby/compat/jax.py b/bilby/compat/jax.py index 7c35aa6f3..f94f64cd8 100644 --- a/bilby/compat/jax.py +++ b/bilby/compat/jax.py @@ -1,26 +1,6 @@ -from functools import partial - import jax import jax.numpy as jnp -from ..core.likelihood import Likelihood, _safe_likelihood_call - - -def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True): - """ - A wrapper to allow a :code:`Bilby` likelihood to be used with :code:`jax`. 
- - Parameters - ========== - likelihood: bilby.core.likelihood.Likelihood - The likelihood to evaluate. - parameters: dict - The parameters to evaluate the likelihood at. - use_ratio: bool, optional - Whether to evaluate the likelihood ratio or the full likelihood. - Default is :code:`True`. - """ - parameters = {k: jnp.array(v) for k, v in parameters.items()} - return _safe_likelihood_call(likelihood, parameters, use_ratio) +from ..core.likelihood import Likelihood class JittedLikelihood(Likelihood): @@ -36,34 +16,30 @@ class JittedLikelihood(Likelihood): ========== likelihood: bilby.core.likelihood.Likelihood The likelihood to wrap. - likelihood_func: callable, optional - The function to use to evaluate the likelihood. Default is - :code:`generic_bilby_likelihood_function`. This function should take the - likelihood and parameters as arguments along with additional keyword arguments. - kwargs: dict, optional - Additional keyword arguments to pass to the likelihood function. + cast_to_float: bool + Whether to return a float instead of a :code:`jax.Array`. 
""" - def __init__( - self, - likelihood, - likelihood_func=generic_bilby_likelihood_function, - kwargs=None, - cast_to_float=True, - ): - if kwargs is None: - kwargs = dict() - self.kwargs = kwargs + def __init__(self, likelihood, cast_to_float=True): self._likelihood = likelihood - self.likelihood_func = jax.jit(partial(likelihood_func, likelihood)) + self._ll = jax.jit(likelihood.log_likelihood) + self._llr = jax.jit(likelihood.log_likelihood_ratio) self.cast_to_float = cast_to_float - super().__init__(dict()) + super().__init__() def __getattr__(self, name): return getattr(self._likelihood, name) - def log_likelihood_ratio(self): - ln_l = jnp.nan_to_num(self.likelihood_func(self.parameters, **self.kwargs)) + def log_likelihood(self, parameters): + parameters = {k: jnp.array(v) for k, v in parameters.items()} + ln_l = self._ll(parameters) + if self.cast_to_float: + ln_l = float(ln_l) + return ln_l + + def log_likelihood_ratio(self, parameters): + parameters = {k: jnp.array(v) for k, v in parameters.items()} + ln_l = self._llr(parameters) if self.cast_to_float: ln_l = float(ln_l) return ln_l diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py index bd97f828e..b19473e92 100644 --- a/bilby/compat/utils.py +++ b/bilby/compat/utils.py @@ -1,5 +1,5 @@ import numpy as np -from scipy._lib._array_api import array_namespace +from array_api_compat import array_namespace __all__ = ["array_module", "promote_to_array"] diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index 0f0e50b8b..373a827d0 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -4,6 +4,7 @@ from warnings import warn import numpy as np +from array_api_compat import is_array_api_obj from scipy.special import gammaln, xlogy from scipy.stats import multivariate_normal @@ -283,9 +284,10 @@ def __init__(self, x, y, func, sigma=None, **kwargs): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) + xp = array_module(self.x) sigma 
= parameters.get("sigma", self.sigma) - log_l = np.sum(- (self.residual(parameters) / sigma)**2 / 2 - - np.log(2 * np.pi * sigma**2) / 2) + log_l = xp.sum(- (self.residual(parameters) / sigma)**2 / 2 - + xp.log(2 * np.pi * sigma**2) / 2) return log_l def __repr__(self): @@ -344,17 +346,18 @@ def __init__(self, x, y, func, **kwargs): def log_likelihood(self, parameters=None): rate = self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) - if not isinstance(rate, np.ndarray): + if not is_array_api_obj(rate): raise ValueError( "Poisson rate function returns wrong value type! " "Is {} when it should be numpy.ndarray".format(type(rate))) - elif np.any(rate < 0.): + elif any(rate < 0.): raise ValueError(("Poisson rate function returns a negative", " value!")) - elif np.any(rate == 0.): + elif any(rate == 0.): return -np.inf else: - return np.sum(-rate + self.y * np.log(rate) - gammaln(self.y + 1)) + xp = array_module(rate) + return xp.sum(-rate + self.y * xp.log(rate) - gammaln(self.y + 1)) def __repr__(self): return Analytical1DLikelihood.__repr__(self) @@ -395,9 +398,10 @@ def __init__(self, x, y, func, **kwargs): def log_likelihood(self, parameters=None): mu = self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) - if np.any(mu < 0.): + if any(mu < 0.): return -np.inf - return -np.sum(np.log(mu) + (self.y / mu)) + xp = array_module(mu) + return -xp.sum(xp.log(mu) + (self.y / mu)) def __repr__(self): return Analytical1DLikelihood.__repr__(self) @@ -411,7 +415,7 @@ def y(self): def y(self, y): if not isinstance(y, np.ndarray): y = np.array([y]) - if np.any(y < 0): + if any(y < 0): raise ValueError("Data must be non-negative") self._y = y @@ -458,9 +462,10 @@ def log_likelihood(self, parameters=None): raise ValueError("Number of degrees of freedom for Student's " "t-likelihood must be positive") + xp = array_module(self.x) log_l =\ - np.sum(- (nu + 1) * np.log1p(self.lam * self.residual(parameters=parameters)**2 / nu) 
/ 2 + - np.log(self.lam / (nu * np.pi)) / 2 + + xp.sum(- (nu + 1) * xp.log1p(self.lam * self.residual(parameters=parameters)**2 / nu) / 2 + + xp.log(self.lam / (nu * np.pi)) / 2 + gammaln((nu + 1) / 2) - gammaln(nu / 2)) return log_l @@ -507,8 +512,10 @@ def __init__(self, data, n_dimensions, base="parameter_"): base: str The base of the parameter labels """ - self.data = np.array(data) - self._total = np.sum(self.data) + if not is_array_api_obj(data): + data = np.array(data) + self.data = data + self._total = self.data.sum() super(Multinomial, self).__init__() self.n = n_dimensions self.base = base @@ -535,7 +542,8 @@ def noise_log_likelihood(self): def _multinomial_ln_pdf(self, probs): """Lifted from scipy.stats.multinomial._logpdf""" - ln_prob = gammaln(self._total + 1) + np.sum( + xp = array_module(self.data) + ln_prob = gammaln(self._total + 1) + xp.sum( xlogy(self.data, probs) - gammaln(self.data + 1), axis=-1) return ln_prob @@ -557,7 +565,7 @@ def __init__(self, mean, cov): xp = array_module(cov) self.cov = xp.atleast_2d(cov) self.mean = xp.atleast_1d(mean) - self.sigma = xp.sqrt(np.diag(self.cov)) + self.sigma = xp.sqrt(xp.diag(self.cov)) if xp == np: self.logpdf = multivariate_normal(mean=self.mean, cov=self.cov).logpdf else: @@ -595,7 +603,7 @@ class AnalyticalMultidimensionalBimodalCovariantGaussian(Likelihood): def __init__(self, mean_1, mean_2, cov): xp = array_module(cov) self.cov = xp.atleast_2d(cov) - self.sigma = xp.sqrt(np.diag(self.cov)) + self.sigma = xp.sqrt(xp.diag(self.cov)) self.mean_1 = xp.atleast_1d(mean_1) self.mean_2 = xp.atleast_1d(mean_2) if xp == np: diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 876196ec6..755fd3f41 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -17,7 +17,7 @@ from .base import Prior from ..utils import logger -from ...compat.utils import xp_wrap +from ...compat.utils import array_module, xp_wrap class DeltaFunction(Prior): @@ -363,7 +363,7 @@ 
def ln_prob(self, val, *, xp=np): float: """ - return np.nan_to_num(- xp.log(2 * xp.abs(val)) - xp.log(xp.log(self.maximum / self.minimum))) + return xp.nan_to_num(- xp.log(2 * xp.abs(val)) - xp.log(xp.log(self.maximum / self.minimum))) @xp_wrap def cdf(self, val, *, xp=np): @@ -1055,7 +1055,8 @@ def rescale(self, val, *, xp=np): val = xp.asarray(val) return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), 0)) - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. Parameters @@ -1066,7 +1067,7 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val)) @xp_wrap def ln_prob(self, val, *, xp=np): @@ -1082,10 +1083,11 @@ def ln_prob(self, val, *, xp=np): """ with np.errstate(over="ignore"): return -(val - self.mu) / self.scale -\ - 2. * np.log1p(xp.exp(-(val - self.mu) / self.scale)) - np.log(self.scale) + 2. * xp.log1p(xp.exp(-(val - self.mu) / self.scale)) - xp.log(self.scale) - def cdf(self, val): - return 1. / (1. + np.exp(-(val - self.mu) / self.scale)) + @xp_wrap + def cdf(self, val, *, xp=np): + return 1. / (1. + xp.exp(-(val - self.mu) / self.scale)) class Cauchy(Prior): @@ -1141,7 +1143,8 @@ def prob(self, val): """ return 1. / self.beta / np.pi / (1. + ((val - self.alpha) / self.beta) ** 2) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Return the log prior probability of val. Parameters @@ -1152,10 +1155,11 @@ def ln_prob(self, val): ======= Union[float, array_like]: Log prior probability of val """ - return - np.log(self.beta * np.pi) - np.log(1. + ((val - self.alpha) / self.beta) ** 2) + return - xp.log(self.beta * np.pi) - xp.log(1. 
+ ((val - self.alpha) / self.beta) ** 2) - def cdf(self, val): - return 0.5 + np.arctan((val - self.alpha) / self.beta) / np.pi + @xp_wrap + def cdf(self, val, *, xp=np): + return 0.5 + xp.arctan((val - self.alpha) / self.beta) / np.pi class Lorentzian(Cauchy): @@ -1323,9 +1327,11 @@ def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, raise ValueError("For the Fermi-Dirac prior the values of sigma and r " "must be positive.") - self.expr = np.exp(self.r) + xp = array_module(np) + self.expr = xp.exp(self.r) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the appropriate Fermi-Dirac prior. @@ -1343,7 +1349,7 @@ def rescale(self, val): `_, 2017. """ inv = -1 / self.expr + (1 + self.expr)**-val + (1 + self.expr)**-val / self.expr - return -self.sigma * np.log(np.maximum(inv, 0)) + return -self.sigma * xp.log(xp.maximum(inv, 0)) @xp_wrap def prob(self, val, *, xp=np): @@ -1363,7 +1369,8 @@ def prob(self, val, *, xp=np): * (val >= self.minimum) ) - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=np): """Return the log prior probability of val. Parameters @@ -1374,9 +1381,10 @@ def ln_prob(self, val): ======= Union[float, array_like]: Log prior probability of val """ - return np.log(self.prob(val)) + return xp.log(self.prob(val)) - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=np): """ Evaluate the CDF of the Fermi-Dirac distribution using a slightly modified form of Equation 23 of [1]_. @@ -1398,10 +1406,10 @@ def cdf(self, val): `_, 2017. 
""" result = ( - (np.logaddexp(0, -self.r) - np.logaddexp(-val / self.sigma, -self.r)) - / np.logaddexp(0, self.r) + (xp.logaddexp(0, -self.r) - xp.logaddexp(-val / self.sigma, -self.r)) + / xp.logaddexp(0, self.r) ) - return np.clip(result, 0, 1) + return xp.clip(result, 0, 1) class WeightedDiscreteValues(Prior): diff --git a/bilby/gw/detector/geometry.py b/bilby/gw/detector/geometry.py index 627f6a143..5d0de9b9f 100644 --- a/bilby/gw/detector/geometry.py +++ b/bilby/gw/detector/geometry.py @@ -304,3 +304,13 @@ def unit_vector_along_arm(self, arm): ) else: raise ValueError("Arm must either be 'x' or 'y'.") + + def set_array_backend(self, xp): + self.length = xp.array(self.length) + self.latitude = xp.array(self.latitude) + self.longitude = xp.array(self.longitude) + self.elevation = xp.array(self.elevation) + self.xarm_azimuth = xp.array(self.xarm_azimuth) + self.yarm_azimuth = xp.array(self.yarm_azimuth) + self.xarm_tilt = xp.array(self.xarm_tilt) + self.yarm_tilt = xp.array(self.yarm_tilt) diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 9a8584ddf..937a15d0e 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -3,7 +3,7 @@ import numpy as np from ...core import utils -from ...core.utils import docstring, logger, PropertyAccessor, safe_file_dump +from ...core.utils import PropertyAccessor, docstring, logger, safe_file_dump from ...core.utils.env import string_to_boolean from ...compat.utils import array_module from .. 
import utils as gwutils @@ -929,14 +929,8 @@ def from_pickle(cls, filename=None): return res def set_array_backend(self, xp): - for attr in [ - "length", - "latitude", - "longitude", - "elevation", - "xarm_azimuth", - "yarm_azimuth", - "xarm_tilt", - "yarm_tilt", - ]: - setattr(self, attr, xp.array(getattr(self, attr))) + self.geometry.set_array_backend(xp=xp) + + @property + def array_backend(self): + return array_module(self.geometry.length) diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index 988c1b76e..58e0cd856 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -336,6 +336,10 @@ def set_array_backend(self, xp): for ifo in self: ifo.set_array_backend(xp) + @property + def array_backend(self): + return self[0].array_backend + class TriangularInterferometer(InterferometerList): def __init__( diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index 53caf0fcb..d5459bf48 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -6,7 +6,6 @@ import numpy as np from scipy.special import logsumexp -from ...compat.utils import array_module from ...core.likelihood import Likelihood, _fallback_to_parameters from ...core.utils import logger, BoundedRectBivariateSpline, create_time_series from ...core.prior import Interped, Prior, Uniform, DeltaFunction @@ -158,6 +157,7 @@ def __init__( self.waveform_generator = waveform_generator super(GravitationalWaveTransient, self).__init__() self.interferometers = InterferometerList(interferometers) + self.interferometers.set_array_backend(interferometers.array_backend) self.time_marginalization = time_marginalization self.distance_marginalization = distance_marginalization self.phase_marginalization = phase_marginalization @@ -170,7 +170,7 @@ def __init__( if "geocent" not in time_reference: self.time_reference = time_reference self.reference_ifo = get_empty_interferometer(self.time_reference) - 
self.reference_ifo.set_array_backend(array_module(self.interferometers[0].vertex)) + self.reference_ifo.set_array_backend(self.interferometers.array_backend) if self.time_marginalization: logger.info("Cannot marginalise over non-geocenter time.") self.time_marginalization = False @@ -398,12 +398,12 @@ def _calculate_noise_log_likelihood(self): log_l = 0 for interferometer in self.interferometers: mask = interferometer.frequency_mask - log_l -= noise_weighted_inner_product( + log_l -= abs(noise_weighted_inner_product( interferometer.frequency_domain_strain[mask], interferometer.frequency_domain_strain[mask], interferometer.power_spectral_density_array[mask], - self.waveform_generator.duration) / 2 - return float(np.real(log_l)) + self.waveform_generator.duration) / 2) + return log_l def noise_log_likelihood(self): # only compute likelihood if called for the 1st time @@ -1093,6 +1093,8 @@ def reference_frame(self, frame): self._reference_frame = InterferometerList([frame[:2], frame[2:4]]) else: raise ValueError("Unable to parse reference frame {}".format(frame)) + if isinstance(self._reference_frame, InterferometerList): + self._reference_frame.set_array_backend(self.interferometers.array_backend) def get_sky_frame_parameters(self, parameters=None): """ diff --git a/bilby/gw/likelihood/basic.py b/bilby/gw/likelihood/basic.py index b2f04eb69..c4e04987b 100644 --- a/bilby/gw/likelihood/basic.py +++ b/bilby/gw/likelihood/basic.py @@ -43,10 +43,11 @@ def noise_log_likelihood(self): """ log_l = 0 for interferometer in self.interferometers: - log_l -= 2. / self.waveform_generator.duration * np.sum( - abs(interferometer.frequency_domain_strain) ** 2 / - interferometer.power_spectral_density_array) - return log_l.real + log_l -= 2. 
/ self.waveform_generator.duration * ( + abs(interferometer.frequency_domain_strain) ** 2 + / interferometer.power_spectral_density_array + ).sum() + return log_l def log_likelihood(self, parameters=None): """ Calculates the real part of log-likelihood value @@ -87,8 +88,9 @@ def log_likelihood_interferometer(self, waveform_polarizations, signal_ifo = interferometer.get_detector_response( waveform_polarizations, parameters) - log_l = - 2. / self.waveform_generator.duration * np.vdot( - interferometer.frequency_domain_strain - signal_ifo, - (interferometer.frequency_domain_strain - signal_ifo) / - interferometer.power_spectral_density_array) + residual = interferometer.frequency_domain_strain - signal_ifo + + log_l = - 2. / self.waveform_generator.duration * ( + abs(residual)**2 / interferometer.power_spectral_density_array + ).sum() return log_l.real diff --git a/bilby/gw/sampler/proposal.py b/bilby/gw/sampler/proposal.py index 79e1ec92c..2ac84687e 100644 --- a/bilby/gw/sampler/proposal.py +++ b/bilby/gw/sampler/proposal.py @@ -13,7 +13,7 @@ class SkyLocationWanderJump(JumpProposal): def __call__(self, sample, **kwargs): temperature = 1 / kwargs.get('inverse_temperature', 1.0) - sigma = np.sqrt(temperature) / 2 / np.pi + sigma = temperature**0.5 / 2 / np.pi sample['ra'] += random.gauss(0, sigma) sample['dec'] += random.gauss(0, sigma) return super(SkyLocationWanderJump, self).__call__(sample) diff --git a/bilby/gw/source.py b/bilby/gw/source.py index b38d56436..96973efd9 100644 --- a/bilby/gw/source.py +++ b/bilby/gw/source.py @@ -1203,7 +1203,7 @@ def sinegaussian(frequency_array, hrss, Q, frequency, **kwargs): h_cross = -1j * hrss * np.pi**0.5 * tau / 2 * ( negative_term - positive_term - ) / (temp * (1 - np.exp(-Q**2)))**0.5 + ) / (temp * (1 - xp.exp(-Q**2)))**0.5 return {'plus': h_plus, 'cross': h_cross} @@ -1286,12 +1286,13 @@ def supernova_pca_model( dict: The plus and cross polarizations of the signal """ + xp = array_module(frequency_array) 
principal_components = kwargs["realPCs"] + 1j * kwargs["imagPCs"] coefficients = [pc_coeff1, pc_coeff2, pc_coeff3, pc_coeff4, pc_coeff5] - strain = np.sum( - [coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)], + strain = xp.sum( + xp.array([coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)]), axis=0 ) From 913428e8ddd352e8ac1a6c6b458bd18423504305 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 2 Oct 2025 19:39:11 +0000 Subject: [PATCH 032/110] HYPER: make hyperparameter likelihood handle array backends --- bilby/hyper/likelihood.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/bilby/hyper/likelihood.py b/bilby/hyper/likelihood.py index 4b691845e..2c4b65abd 100644 --- a/bilby/hyper/likelihood.py +++ b/bilby/hyper/likelihood.py @@ -3,6 +3,7 @@ import numpy as np +from ..compat.utils import array_module from ..core.likelihood import Likelihood, _fallback_to_parameters from .model import Model from ..core.prior import PriorDict @@ -29,11 +30,13 @@ class HyperparameterLikelihood(Likelihood): the sampling prior and the hyperparameterised model. max_samples: int, optional Maximum number of samples to use from each set. + xp: module + The array backend to use for the data. 
""" def __init__(self, posteriors, hyper_prior, sampling_prior=None, - log_evidences=None, max_samples=1e100): + log_evidences=None, max_samples=1e100, xp=np): if not isinstance(hyper_prior, Model): hyper_prior = Model([hyper_prior]) if sampling_prior is None: @@ -53,7 +56,7 @@ def __init__(self, posteriors, hyper_prior, sampling_prior=None, self.max_samples = max_samples super(HyperparameterLikelihood, self).__init__() - self.data = self.resample_posteriors() + self.data = self.resample_posteriors(xp=xp) self.n_posteriors = len(self.posteriors) self.samples_per_posterior = self.max_samples self.samples_factor =\ @@ -61,10 +64,11 @@ def __init__(self, posteriors, hyper_prior, sampling_prior=None, def log_likelihood_ratio(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) - log_l = np.sum(np.log(np.sum(self.hyper_prior.prob(self.data, **parameters) / - self.data['prior'], axis=-1))) + probs = self.hyper_prior.prob(self.data, **parameters) + xp = array_module(probs) + log_l = xp.sum(xp.log(xp.sum(probs / self.data['prior'], axis=-1))) log_l += self.samples_factor - return np.nan_to_num(log_l) + return xp.nan_to_num(log_l) def noise_log_likelihood(self): return self.evidence_factor @@ -72,7 +76,7 @@ def noise_log_likelihood(self): def log_likelihood(self, parameters=None): return self.noise_log_likelihood() + self.log_likelihood_ratio(parameters=parameters) - def resample_posteriors(self, max_samples=None): + def resample_posteriors(self, max_samples=None, xp=np): """ Convert list of pandas DataFrame object to dict of arrays. 
@@ -107,5 +111,5 @@ def resample_posteriors(self, max_samples=None): for key in data: data[key].append(temp[key]) for key in data: - data[key] = np.array(data[key]) + data[key] = xp.array(data[key]) return data From 2fa175268bc4e67cd8d0c556dac61b0e7f7999a6 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 11 Dec 2025 15:46:40 +0000 Subject: [PATCH 033/110] MAINT: switch back to bilby_cython --- bilby/gw/geometry.py | 2 +- bilby/gw/time.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bilby/gw/geometry.py b/bilby/gw/geometry.py index 60d55f8be..8b0cf2855 100644 --- a/bilby/gw/geometry.py +++ b/bilby/gw/geometry.py @@ -1,6 +1,6 @@ import numpy as np from plum import dispatch -from bilby_rust import geometry as _geometry +from bilby_cython import geometry as _geometry from .time import greenwich_mean_sidereal_time from ..compat.types import Real, ArrayLike diff --git a/bilby/gw/time.py b/bilby/gw/time.py index cb56b3bab..29a2b9c77 100644 --- a/bilby/gw/time.py +++ b/bilby/gw/time.py @@ -2,7 +2,7 @@ import numpy as np from plum import dispatch -from bilby_rust import time as _time +from bilby_cython import time as _time from ..compat.types import Real, ArrayLike from ..compat.utils import array_module From 88241967867518a8540a4bd85a906d7ba538e2a9 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 22 Dec 2025 16:12:50 -0500 Subject: [PATCH 034/110] TYPO: fix typo in multiband time-marginalized likelihood --- bilby/gw/likelihood/multiband.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index d7ffabfde..1ece29154 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -761,7 +761,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr ) if self.time_marginalization: - parameters["geocent_time"] = origianl_time + parameters["geocent_time"] = original_time d_inner_h = (strain @ 
self.linear_coeffs[interferometer.name]).conjugate() @@ -800,7 +800,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr start_idx, end_idx = self.start_end_idxs[b] self._full_d_h[self._full_to_multiband[start_idx:end_idx + 1]] += \ strain[start_idx:end_idx + 1] * self.linear_coeffs[interferometer.name][start_idx:end_idx + 1] - d_inner_h_array = np.fft.fft(self._full_d_h) + d_inner_h_array = xp.fft.fft(self._full_d_h) else: d_inner_h_array = None From dfb325624fde1ee847b321fcc54e922ed5d0e1b5 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 22 Dec 2025 16:13:05 -0500 Subject: [PATCH 035/110] MAINT: removed unused import --- bilby/gw/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index f2a7499e4..07a9e9aa1 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -5,7 +5,6 @@ import numpy as np from scipy.interpolate import interp1d from scipy.special import i0e -from plum import dispatch from .geometry import zenith_azimuth_to_theta_phi from .time import greenwich_mean_sidereal_time From 1c03740068940aa86992a4766893be66ed9f78aa Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 22 Dec 2025 16:13:24 -0500 Subject: [PATCH 036/110] BUG: add explicit array cast in conversion --- bilby/gw/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index ca47dec57..2fbbfdd5c 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -287,7 +287,7 @@ def convert_to_lal_binary_black_hole_parameters(parameters): converted_parameters['a_{}'.format(idx)] = abs( converted_parameters[key]) converted_parameters['cos_tilt_{}'.format(idx)] = \ - xp.sign(converted_parameters[key]) + xp.sign(xp.asarray(converted_parameters[key])) else: with np.errstate(invalid="raise"): try: From 0f237d67ae90eebd1d7b7d800916c04f337cc85e Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 22 Dec 2025 16:13:56 -0500 Subject: [PATCH 
037/110] REFACTOR: some refactoring of array edge cases --- bilby/compat/patches.py | 37 ++++++++++++++++++++++++++++++++++ bilby/compat/utils.py | 21 ++++++++++++++++--- bilby/core/likelihood.py | 32 ++++++++++++++--------------- bilby/core/prior/analytical.py | 29 ++++++++++++++------------ bilby/core/prior/dict.py | 12 +++++------ bilby/core/utils/calculus.py | 12 +++++++++-- 6 files changed, 103 insertions(+), 40 deletions(-) create mode 100644 bilby/compat/patches.py diff --git a/bilby/compat/patches.py b/bilby/compat/patches.py new file mode 100644 index 000000000..02cbc3394 --- /dev/null +++ b/bilby/compat/patches.py @@ -0,0 +1,37 @@ +import array_api_compat as aac + +from .utils import BackendNotImplementedError + + +def erfinv_import(xp): + if aac.is_numpy_namespace(xp): + from scipy.special import erfinv + elif aac.is_jax_namespace(xp): + from jax.scipy.special import erfinv + elif aac.is_torch_namespace(xp): + from torch.special import erfinv + elif aac.is_cupy_namespace(xp): + from cupyx.scipy.special import erfinv + else: + raise BackendNotImplementedError + return erfinv + + +def multivariate_logpdf(xp, mean, cov): + if aac.is_numpy_namespace(xp): + from scipy.stats import multivariate_normal + + logpdf = multivariate_normal(mean=mean, cov=cov).logpdf + elif aac.is_jax_namespace(xp): + from functools import partial + from jax.scipy.stats.multivariate_normal import logpdf + + logpdf = partial(logpdf, mean=mean, cov=cov) + elif aac.is_torch_namespace(xp): + from torch.distributions.multivariate_normal import MultivariateNormal + + mvn = MultivariateNormal(loc=mean, covariance_matrix=xp.array(cov)) + logpdf = mvn.log_prob + else: + raise BackendNotImplementedError + return logpdf diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py index b19473e92..9f83ac9e5 100644 --- a/bilby/compat/utils.py +++ b/bilby/compat/utils.py @@ -1,14 +1,25 @@ import numpy as np from array_api_compat import array_namespace +from ..core.utils.log import logger + __all__ = 
["array_module", "promote_to_array"] def array_module(arr): - if arr.__class__.__module__ == "builtins": - return np - else: + try: return array_namespace(arr) + except TypeError: + if arr.__class__.__module__ == "builtins": + return np + elif arr.__module__.startswith("pandas"): + return np + else: + logger.warning( + f"Unknown array module for type: {type(arr)} Defaulting to numpy." + ) + return np + def promote_to_array(args, backend, skip=None): @@ -32,3 +43,7 @@ def wrapped(self, *args, **kwargs): return func(self, *args, **kwargs) return wrapped + + +class BackendNotImplementedError(NotImplementedError): + pass diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index 373a827d0..611885e16 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -6,10 +6,10 @@ import numpy as np from array_api_compat import is_array_api_obj from scipy.special import gammaln, xlogy -from scipy.stats import multivariate_normal from .utils import infer_parameters_from_function, infer_args_from_function_except_n_args, logger -from ..compat.utils import array_module +from ..compat.patches import multivariate_logpdf +from ..compat.utils import BackendNotImplementedError, array_module PARAMETERS_AS_STATE = os.environ.get("BILBY_ALLOW_PARAMETERS_AS_STATE", "TRUE") @@ -566,12 +566,13 @@ def __init__(self, mean, cov): self.cov = xp.atleast_2d(cov) self.mean = xp.atleast_1d(mean) self.sigma = xp.sqrt(xp.diag(self.cov)) - if xp == np: - self.logpdf = multivariate_normal(mean=self.mean, cov=self.cov).logpdf - else: - from functools import partial - from jax.scipy.stats.multivariate_normal import logpdf - self.logpdf = partial(logpdf, mean=self.mean, cov=self.cov) + try: + self.logpdf = multivariate_logpdf(xp, mean=self.mean, cov=self.cov) + except BackendNotImplementedError: + raise NotImplementedError( + f"Multivariate normal likelihood not implemented for {xp.__name__} backend" + ) + parameters = {"x{0}".format(i): 0 for i in range(self.dim)} 
super(AnalyticalMultidimensionalCovariantGaussian, self).__init__(parameters=parameters) @@ -606,14 +607,13 @@ def __init__(self, mean_1, mean_2, cov): self.sigma = xp.sqrt(xp.diag(self.cov)) self.mean_1 = xp.atleast_1d(mean_1) self.mean_2 = xp.atleast_1d(mean_2) - if xp == np: - self.logpdf_1 = multivariate_normal(mean=self.mean_1, cov=self.cov).logpdf - self.logpdf_2 = multivariate_normal(mean=self.mean_2, cov=self.cov).logpdf - else: - from functools import partial - from jax.scipy.stats.multivariate_normal import logpdf - self.logpdf_1 = partial(logpdf, mean=self.mean_1, cov=self.cov) - self.logpdf_2 = partial(logpdf, mean=self.mean_2, cov=self.cov) + try: + self.logpdf_1 = multivariate_logpdf(xp, mean=self.mean_1, cov=self.cov) + self.logpdf_2 = multivariate_logpdf(xp, mean=self.mean_2, cov=self.cov) + except BackendNotImplementedError: + raise NotImplementedError( + f"Multivariate normal likelihood not implemented for {xp.__name__} backend" + ) parameters = {"x{0}".format(i): 0 for i in range(self.dim)} super(AnalyticalMultidimensionalBimodalCovariantGaussian, self).__init__(parameters=parameters) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 755fd3f41..dc418ed4a 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -17,7 +17,8 @@ from .base import Prior from ..utils import logger -from ...compat.utils import array_module, xp_wrap +from ...compat.patches import erfinv_import +from ...compat.utils import BackendNotImplementedError, array_module, xp_wrap class DeltaFunction(Prior): @@ -526,10 +527,10 @@ def rescale(self, val, *, xp=np): This maps to the inverse CDF. This has been analytically solved for this case. 
""" - if "jax" in xp.__name__: - from jax.scipy.special import erfinv - else: - from scipy.special import erfinv + try: + erfinv = erfinv_import(xp) + except BackendNotImplementedError: + raise NotImplementedError(f"Gaussian prior rescale not implemented for this {xp.__name__}") return self.mu + erfinv(2 * val - 1) * 2 ** 0.5 * self.sigma @xp_wrap @@ -618,10 +619,12 @@ def rescale(self, val, *, xp=np): This maps to the inverse CDF. This has been analytically solved for this case. """ - if "jax" in xp.__name__: - from jax.scipy.special import erfinv - else: - from scipy.special import erfinv + try: + erfinv = erfinv_import(xp) + except BackendNotImplementedError: + raise NotImplementedError( + f"Truncated Gaussian prior rescale not implemented for this {xp.__name__}" + ) return erfinv(2 * val * self.normalisation + erf( (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) * 2 ** 0.5 * self.sigma + self.mu @@ -716,10 +719,10 @@ def rescale(self, val, *, xp=np): This maps to the inverse CDF. This has been analytically solved for this case. 
""" - if "jax" in xp.__name__: - from jax.scipy.special import erfinv - else: - from scipy.special import erfinv + try: + erfinv = erfinv_import(xp) + except BackendNotImplementedError: + raise NotImplementedError(f"LogNormal prior rescale not implemented for this {xp.__name__}") return xp.exp(self.mu + (2 * self.sigma ** 2)**0.5 * erfinv(2 * val - 1)) @xp_wrap diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 7a9a8dcfd..29d0cae0e 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -6,7 +6,6 @@ from warnings import warn import numpy as np -from scipy._lib._array_api import array_namespace from .analytical import DeltaFunction from .base import Prior, Constraint @@ -17,6 +16,7 @@ BilbyJsonEncoder, decode_bilby_json, ) +from ...compat.utils import array_module class PriorDict(dict): @@ -542,7 +542,7 @@ def prob(self, sample, **kwargs): float: Joint probability of all individual sample probabilities """ - xp = array_namespace(*sample.values()) + xp = array_module(sample.values()) prob = xp.prod(xp.asarray([self[key].prob(sample[key]) for key in sample]), **kwargs) return prob @@ -643,7 +643,7 @@ def rescale(self, keys, theta): """ if isinstance(theta, {}.values().__class__): theta = list(theta) - xp = array_namespace(theta) + xp = array_module(theta) return xp.asarray([self[key].rescale(sample) for key, sample in zip(keys, theta)]) @@ -814,7 +814,7 @@ def prob(self, sample, **kwargs): """ self._prepare_evaluation(*zip(*sample.items())) - xp = array_namespace(*sample.values()) + xp = array_module(sample.values()) res = xp.asarray([ self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample @@ -842,7 +842,7 @@ def ln_prob(self, sample, axis=None, normalized=True): """ self._prepare_evaluation(*zip(*sample.items())) - xp = array_namespace(*sample.values()) + xp = array_module(sample.values()) res = xp.array([ self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample @@ -876,7 
+876,7 @@ def rescale(self, keys, theta): """ if isinstance(theta, {}.values().__class__): theta = list(theta) - xp = array_namespace(theta) + xp = array_module(theta) keys = list(keys) self._check_resolved() diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index f20973f4e..6dbce9bf5 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -1,5 +1,6 @@ import math +import array_api_compat as aac import numpy as np from scipy.interpolate import RectBivariateSpline, interp1d as _interp1d from scipy.special import logsumexp @@ -234,10 +235,17 @@ def __init__(self, x, y, z, bbox=[None] * 4, kx=3, ky=3, s=0, fill_value=None): super().__init__(x=x, y=y, z=z, bbox=bbox, kx=kx, ky=ky, s=s) def __call__(self, x, y, dx=0, dy=0, grid=False): - from array_api_compat import is_jax_namespace xp = array_module(x) - if is_jax_namespace(xp): + if aac.is_numpy_namespace(xp): + return self._call_scipy(x, y, dx=dx, dy=dy, grid=grid) + elif aac.is_jax_namespace(xp): return self._call_jax(x, y) + else: + raise NotImplementedError( + f"BoundedRectBivariateSpline not implemented for {xp.__name__} backend" + ) + + def _call_scipy(self, x, y, dx=0, dy=0, grid=False): result = super().__call__(x=x, y=y, dx=dx, dy=dy, grid=grid) out_of_bounds_x = (x < self.x_min) | (x > self.x_max) out_of_bounds_y = (y < self.y_min) | (y > self.y_max) From f54dfa198d28f61e1c342c7cfcba3248d0c49597 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 22 Dec 2025 16:29:42 -0500 Subject: [PATCH 038/110] MAINT: removed extra ripple code --- bilby/gw/jaxstuff.py | 67 -------------------------------------------- 1 file changed, 67 deletions(-) delete mode 100644 bilby/gw/jaxstuff.py diff --git a/bilby/gw/jaxstuff.py b/bilby/gw/jaxstuff.py deleted file mode 100644 index 6046e51bd..000000000 --- a/bilby/gw/jaxstuff.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -Generic dumping ground for jax-specific functions that we need. 
-This should find a home somewhere down the line, but gives an -idea of how much pain is being added. -""" - -import jax -import jax.numpy as jnp -from ripple.waveforms import IMRPhenomPv2 - - -def bilby_to_ripple_spins( - theta_jn, - phi_jl, - tilt_1, - tilt_2, - phi_12, - a_1, - a_2, -): - iota = theta_jn - spin_1x = a_1 * jnp.sin(tilt_1) * jnp.cos(phi_jl) - spin_1y = a_1 * jnp.sin(tilt_1) * jnp.sin(phi_jl) - spin_1z = a_1 * jnp.cos(tilt_1) - spin_2x = a_2 * jnp.sin(tilt_2) * jnp.cos(phi_jl + phi_12) - spin_2y = a_2 * jnp.sin(tilt_2) * jnp.sin(phi_jl + phi_12) - spin_2z = a_2 * jnp.cos(tilt_2) - return iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z - - -wf_func = jax.jit(IMRPhenomPv2.gen_IMRPhenomPv2) - - -def ripple_bbh_relbin( - frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase, - a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, fiducial, **kwargs, -): - if fiducial == 1: - kwargs["frequencies"] = frequency - else: - kwargs["frequencies"] = kwargs.pop("frequency_bin_edges") - return ripple_bbh( - frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase, - a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, **kwargs - ) - - -def ripple_bbh( - frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase, - a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, **kwargs, -): - iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_ripple_spins( - theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2 - ) - if "frequencies" in kwargs: - frequencies = kwargs["frequencies"] - elif "minimum_frequency" in kwargs: - frequencies = jnp.maximum(frequency, kwargs["minimum_frequency"]) - else: - frequencies = frequency - theta = jnp.array([ - mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, - luminosity_distance, jnp.array(0.0), phase, iota - ]) - hp, hc = wf_func(frequencies, theta, jax.numpy.array(20.0)) - return dict(plus=hp, cross=hc) - From 0e7fb3e2ca8087094c3c570414d28489b1597888 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: 
Mon, 22 Dec 2025 16:34:15 -0500 Subject: [PATCH 039/110] REFACTOR: make bilby_cython an optional dependency --- bilby/compat/types.py | 6 ++-- bilby/compat/utils.py | 1 - bilby/gw/compat/__init__.py | 6 ++++ bilby/gw/compat/cython.py | 66 +++++++++++++++++++++++++++++++++++++ bilby/gw/geometry.py | 61 ---------------------------------- bilby/gw/time.py | 33 ------------------- requirements.txt | 2 -- 7 files changed, 74 insertions(+), 101 deletions(-) create mode 100644 bilby/gw/compat/cython.py diff --git a/bilby/compat/types.py b/bilby/compat/types.py index 8a3391c44..48c74c29f 100644 --- a/bilby/compat/types.py +++ b/bilby/compat/types.py @@ -1,6 +1,4 @@ -from typing import Union import numpy as np -Real = Union[float, int] -ArrayLike = Union[np.ndarray, list, tuple] - +Real = float | int +ArrayLike = np.ndarray | list | tuple diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py index 9f83ac9e5..98fae293d 100644 --- a/bilby/compat/utils.py +++ b/bilby/compat/utils.py @@ -21,7 +21,6 @@ def array_module(arr): return np - def promote_to_array(args, backend, skip=None): if skip is None: skip = len(args) diff --git a/bilby/gw/compat/__init__.py b/bilby/gw/compat/__init__.py index 8e2e63c62..16eea2d00 100644 --- a/bilby/gw/compat/__init__.py +++ b/bilby/gw/compat/__init__.py @@ -2,3 +2,9 @@ from .jax import n_leap_seconds except ModuleNotFoundError: pass + + +try: + from .cython import gps_time_to_utc +except ModuleNotFoundError: + pass \ No newline at end of file diff --git a/bilby/gw/compat/cython.py b/bilby/gw/compat/cython.py new file mode 100644 index 000000000..f42301875 --- /dev/null +++ b/bilby/gw/compat/cython.py @@ -0,0 +1,66 @@ +import numpy as np +from bilby_cython import time as _time, geometry as _geometry +from plum import dispatch + +from ...compat.types import Real, ArrayLike + + +@dispatch(precedence=1) +def gps_time_to_utc(gps_time: Real): + return _time.gps_time_to_utc(gps_time) + + +@dispatch(precedence=1) +def 
greenwich_mean_sidereal_time(gps_time: Real | ArrayLike): + return _time.greenwich_mean_sidereal_time(gps_time) + + +@dispatch(precedence=1) +def greenwich_sidereal_time(gps_time: Real, equation_of_equinoxes: Real): + return _time.greenwich_sidereal_time(gps_time, equation_of_equinoxes) + + +@dispatch(precedence=1) +def n_leap_seconds(gps_time: Real): + return _time.n_leap_seconds(gps_time) + + +@dispatch(precedence=1) +def utc_to_julian_day(utc_time: Real): + return _time.utc_to_julian_day(utc_time) + + +@dispatch(precedence=1) +def calculate_arm(arm_tilt: Real, arm_azimuth: Real, longitude: Real, latitude: Real): + return _geometry.calculate_arm(arm_tilt, arm_azimuth, longitude, latitude) + + +@dispatch(precedence=1) +def detector_tensor(x: ArrayLike, y: ArrayLike): + return _geometry.detector_tensor(x, y) + + +@dispatch(precedence=1) +def get_polarization_tensor(ra: Real, dec: Real, time: Real, psi: Real, mode: str): + return _geometry.get_polarization_tensor(ra, dec, time, psi, mode) + + +@dispatch(precedence=1) +def rotation_matrix_from_delta(delta: ArrayLike): + return _geometry.rotation_matrix_from_delta_x(delta) + + +@dispatch(precedence=1) +def time_delay_geocentric(detector1: ArrayLike, detector2: ArrayLike, ra, dec, time): + return _geometry.time_delay_geocentric(detector1, detector2, ra, dec, time) + + +@dispatch(precedence=1) +def time_delay_from_geocenter(detector1: ArrayLike, ra: Real, dec: Real, time: Real | ArrayLike): + return _geometry.time_delay_from_geocenter(detector1, ra, dec, time) + + +@dispatch(precedence=1) +def zenith_azimuth_to_theta_phi(zenith: Real, azimuth: Real, delta_x: np.ndarray): + theta, phi = _geometry.zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) + return theta, phi % (2 * np.pi) diff --git a/bilby/gw/geometry.py b/bilby/gw/geometry.py index 8b0cf2855..fe5ca027f 100644 --- a/bilby/gw/geometry.py +++ b/bilby/gw/geometry.py @@ -1,9 +1,6 @@ -import numpy as np from plum import dispatch -from bilby_cython import geometry 
as _geometry from .time import greenwich_mean_sidereal_time -from ..compat.types import Real, ArrayLike from ..compat.utils import array_module, promote_to_array @@ -198,61 +195,3 @@ def zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x): theta = xp.arccos(omega[2]) phi = xp.arctan2(omega[1], omega[0]) % (2 * xp.pi) return theta, phi - - - -# @dispatch(precedence=1) -# def antenna_response(detector_tensor: np.ndarray, ra: FloatOrInt, dec: FloatOrInt, time: FloatOrInt, psi: FloatOrInt, mode: str): -# return _geometry.antenna_response(detector_tensor, ra, dec, time, psi, mode) - - -@dispatch(precedence=1) -def calculate_arm(arm_tilt: Real, arm_azimuth: Real, longitude: Real, latitude: Real): - return _geometry.calculate_arm(arm_tilt, arm_azimuth, longitude, latitude) - - -@dispatch(precedence=1) -def detector_tensor(x: ArrayLike, y: ArrayLike): - return _geometry.detector_tensor(x, y) - - -@dispatch(precedence=1) -def get_polarization_tensor(ra: Real, dec: Real, time: Real, psi: Real, mode: str): - return _geometry.get_polarization_tensor(ra, dec, time, psi, mode) - - -# @dispatch(precedence=1) -# def get_polarization_tensor_multiple_modes(ra: FloatOrInt, dec: FloatOrInt, time: FloatOrInt, psi: FloatOrInt, modes: list[str]): -# return [geometry.get_polarization_tensor(ra, dec, time, psi, mode) for mode in modes] - - -@dispatch(precedence=1) -def rotation_matrix_from_delta(delta: ArrayLike): - return _geometry.rotation_matrix_from_delta_x(delta) - - -# @dispatch(precedence=1) -# def three_by_three_matrix_contraction(a: ArrayLike, b: ArrayLike): -# return _geometry.three_by_three_matrix_contraction(a, b) - - -@dispatch(precedence=1) -def time_delay_geocentric(detector1: ArrayLike, detector2: ArrayLike, ra, dec, time): - return _geometry.time_delay_geocentric(detector1, detector2, ra, dec, time) - - -@dispatch(precedence=1) -def time_delay_from_geocenter(detector1: ArrayLike, ra: Real, dec: Real, time: Real): - return _geometry.time_delay_from_geocenter(detector1, 
ra, dec, time) - - -@dispatch(precedence=1) -def time_delay_from_geocenter(detector1: ArrayLike, ra: Real, dec: Real, time: ArrayLike): - return _geometry.time_delay_from_geocenter_vectorized(detector1, ra, dec, time) - - -@dispatch(precedence=1) -def zenith_azimuth_to_theta_phi(zenith: Real, azimuth: Real, delta_x: np.ndarray): - theta, phi = _geometry.zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) - return theta, phi % (2 * np.pi) - diff --git a/bilby/gw/time.py b/bilby/gw/time.py index 29a2b9c77..9b085edf3 100644 --- a/bilby/gw/time.py +++ b/bilby/gw/time.py @@ -2,9 +2,7 @@ import numpy as np from plum import dispatch -from bilby_cython import time as _time -from ..compat.types import Real, ArrayLike from ..compat.utils import array_module @@ -226,34 +224,3 @@ def utc_to_julian_day(utc_time): """ return utc_time.julian_day - - -@dispatch(precedence=1) -def gps_time_to_utc(gps_time: Real): - return _time.gps_time_to_utc(gps_time) - - -@dispatch(precedence=1) -def greenwich_mean_sidereal_time(gps_time: Real): - return _time.greenwich_mean_sidereal_time(gps_time) - - -@dispatch(precedence=1) -def greenwich_mean_sidereal_time(gps_time: ArrayLike): - return _time.greenwich_mean_sidereal_time_vectorized(gps_time) - - -@dispatch(precedence=1) -def greenwich_sidereal_time(gps_time: Real, equation_of_equinoxes: Real): - return _time.greenwich_sidereal_time(gps_time, equation_of_equinoxes) - - -@dispatch(precedence=1) -def n_leap_seconds(gps_time: Real): - return _time.n_leap_seconds(gps_time) - - -@dispatch(precedence=1) -def utc_to_julian_day(utc_time: Real): - return _time.utc_to_julian_day(utc_time) - diff --git a/requirements.txt b/requirements.txt index 2b3cc4405..3539f45b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -bilby.cython>=0.3.0 dynesty>=2.0.1 emcee corner @@ -11,6 +10,5 @@ dill tqdm h5py attrs -importlib-metadata>=3.6; python_version < '3.10' plum-dispatch array_api_compat From 8adbdbdfa2b738f968997a6ccff50afa89a8f7e3 Mon 
Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 22 Dec 2025 16:39:04 -0500 Subject: [PATCH 040/110] FMT: formatting fixes --- bilby/core/likelihood.py | 2 +- bilby/core/prior/analytical.py | 2 +- bilby/core/prior/dict.py | 1 - bilby/core/utils/calculus.py | 8 ++++---- bilby/gw/compat/jax.py | 8 +++++--- bilby/gw/geometry.py | 2 -- bilby/gw/likelihood/roq.py | 2 -- bilby/gw/time.py | 18 ++++++++---------- test/core/prior/prior_test.py | 2 +- 9 files changed, 20 insertions(+), 25 deletions(-) diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index 611885e16..88947b5a2 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -572,7 +572,7 @@ def __init__(self, mean, cov): raise NotImplementedError( f"Multivariate normal likelihood not implemented for {xp.__name__} backend" ) - + parameters = {"x{0}".format(i): 0 for i in range(self.dim)} super(AnalyticalMultidimensionalCovariantGaussian, self).__init__(parameters=parameters) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index dc418ed4a..76262dcf2 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -619,7 +619,7 @@ def rescale(self, val, *, xp=np): This maps to the inverse CDF. This has been analytically solved for this case. 
""" - try: + try: erfinv = erfinv_import(xp) except BackendNotImplementedError: raise NotImplementedError( diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 29d0cae0e..b732c7082 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -882,7 +882,6 @@ def rescale(self, keys, theta): self._check_resolved() self._update_rescale_keys(keys) result = dict() - joint = dict() for key, index in zip( self.sorted_keys_without_fixed_parameters, self._rescale_indexes ): diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index 6dbce9bf5..7b9e9b017 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -192,7 +192,7 @@ def logtrapzexp(lnf, dx): class interp1d(_interp1d): - + def __call__(self, x): from array_api_compat import is_numpy_namespace @@ -201,7 +201,7 @@ def __call__(self, x): return super().__call__(x) else: return self._call_alt(x, xp=xp) - + def _call_alt(self, x, *, xp=np): if isinstance(self.fill_value, tuple): left, right = self.fill_value @@ -244,7 +244,7 @@ def __call__(self, x, y, dx=0, dy=0, grid=False): raise NotImplementedError( f"BoundedRectBivariateSpline not implemented for {xp.__name__} backend" ) - + def _call_scipy(self, x, y, dx=0, dy=0, grid=False): result = super().__call__(x=x, y=y, dx=dx, dy=dy, grid=grid) out_of_bounds_x = (x < self.x_min) | (x > self.x_max) @@ -258,7 +258,7 @@ def _call_scipy(self, x, y, dx=0, dy=0, grid=False): return result.item() else: return result - + def _call_jax(self, x, y): import jax.numpy as jnp from interpax import interp2d diff --git a/bilby/gw/compat/jax.py b/bilby/gw/compat/jax.py index 9b0732112..99277e30a 100644 --- a/bilby/gw/compat/jax.py +++ b/bilby/gw/compat/jax.py @@ -2,7 +2,10 @@ from jax import Array from plum import dispatch -from ..time import LEAP_SECONDS as _LEAP_SECONDS, n_leap_seconds +from ..time import ( + LEAP_SECONDS as _LEAP_SECONDS, + n_leap_seconds as _n_leap_seconds, +) __all__ = ["n_leap_seconds"] @@ -14,5 
+17,4 @@ def n_leap_seconds(date: Array): """ Find the number of leap seconds required for the specified date. """ - return n_leap_seconds(date, LEAP_SECONDS) - + return _n_leap_seconds(date, LEAP_SECONDS) diff --git a/bilby/gw/geometry.py b/bilby/gw/geometry.py index fe5ca027f..c07ec5c0c 100644 --- a/bilby/gw/geometry.py +++ b/bilby/gw/geometry.py @@ -148,7 +148,6 @@ def rotation_matrix_from_delta(delta_x): return rotation_3 @ rotation_2 @ rotation_1 - @dispatch def three_by_three_matrix_contraction(a, b): """""" @@ -171,7 +170,6 @@ def time_delay_geocentric(detector1, detector2, ra, dec, time): return omega @ delta_d / speed_of_light - @dispatch def time_delay_from_geocenter(detector1, ra, dec, time): """""" diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 6097557f7..2ccffe1c9 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -457,8 +457,6 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr frequency_nodes = self.waveform_generator.waveform_arguments['frequency_nodes'] linear_indices = self.waveform_generator.waveform_arguments['linear_indices'] quadratic_indices = self.waveform_generator.waveform_arguments['quadratic_indices'] - size_linear = len(linear_indices) - size_quadratic = len(quadratic_indices) h_linear = 0j h_quadratic = 0j for mode in waveform_polarizations['linear']: diff --git a/bilby/gw/time.py b/bilby/gw/time.py index 9b085edf3..51de99e3c 100644 --- a/bilby/gw/time.py +++ b/bilby/gw/time.py @@ -1,5 +1,3 @@ -from typing import Union - import numpy as np from plum import dispatch @@ -125,7 +123,7 @@ def greenwich_mean_sidereal_time(gps_time): ---------- gps_time : float GPS time in seconds. - + Returns ------- float @@ -145,7 +143,7 @@ def greenwich_sidereal_time(gps_time, equation_of_equinoxes): GPS time in seconds. equation_of_equinoxes : float Equation of the equinoxes in seconds. 
- + Returns ------- float @@ -178,26 +176,26 @@ def n_leap_seconds(gps_time, leap_seconds): GPS time in seconds. leap_seconds : array_like GPS time of leap seconds. - + Returns ------- float - Number of leap seconds + Number of leap seconds """ xp = array_module(gps_time) return xp.sum(gps_time > leap_seconds[:, None], axis=0).squeeze() @dispatch -def n_leap_seconds(gps_time: Union[np.ndarray, float, int]): +def n_leap_seconds(gps_time: np.ndarray | float | int): # noqa F811 """ Calculate the number of leap seconds that have occurred up to a given GPS time. Parameters ---------- - gps_time : float + gps_time : float | np.ndarray | int GPS time in seconds. - + Returns ------- float @@ -216,7 +214,7 @@ def utc_to_julian_day(utc_time): ---------- utc_time : datetime UTC time. - + Returns ------- float diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 14f864e90..43ce08fe5 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -853,7 +853,7 @@ def test_set_minimum_setting(self): continue prior.minimum = (prior.maximum + prior.minimum) / 2 self.assertTrue(min(prior.sample(10000)) > prior.minimum) - + def test_jax_methods(self): import jax From 822e08efc3e1bf917f4273feef6a2e2f501f5fa0 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 21 Jan 2026 16:59:14 -0500 Subject: [PATCH 041/110] BUG: fix array introspection for conversion --- bilby/gw/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 2fbbfdd5c..95687e7fb 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -241,7 +241,7 @@ def convert_to_lal_binary_black_hole_parameters(parameters): """ converted_parameters = parameters.copy() original_keys = list(converted_parameters.keys()) - xp = array_module(parameters[original_keys[5]]) + xp = array_module(parameters.values()) if 'luminosity_distance' not in original_keys: if 'redshift' in converted_parameters.keys(): 
converted_parameters['luminosity_distance'] = \ From d11b2c4cdcd5cee02495321836bd208d3e41b049 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 21 Jan 2026 16:59:48 -0500 Subject: [PATCH 042/110] REFACTOR: make parameters for waveform generator more strict --- bilby/gw/waveform_generator.py | 35 +++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py index 3983833a6..060825e8b 100644 --- a/bilby/gw/waveform_generator.py +++ b/bilby/gw/waveform_generator.py @@ -78,8 +78,11 @@ def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequen self.waveform_arguments = waveform_arguments else: self.waveform_arguments = dict() - if isinstance(parameters, dict): - self.parameters = parameters + if parameters is not None: + logger.warning( + "Setting initial parameters via the 'parameters' argument is " + "deprecated and will be removed in a future release." + ) self._cache = dict(parameters=None, waveform=None, model=None) self.use_cache = use_cache logger.info(f"Waveform generator instantiated: {self}") @@ -108,15 +111,13 @@ def __repr__(self): def frequency_domain_strain(self, parameters=None): """ Wrapper to source_model. - Converts self.parameters with self.parameter_conversion before handing it off to the source model. + Converts parameters with self.parameter_conversion before handing it off to the source model. Automatically refers to the time_domain_source model via NFFT if no frequency_domain_source_model is given. Parameters ========== parameters: dict, optional - Parameters to evaluate the waveform for, this overwrites - `self.parameters`. - If not provided will fall back to `self.parameters`. + If not provided will use the last parameters used. Returns ======= @@ -137,16 +138,14 @@ def frequency_domain_strain(self, parameters=None): def time_domain_strain(self, parameters=None): """ Wrapper to source_model. 
- Converts self.parameters with self.parameter_conversion before handing it off to the source model. + Converts parameters with self.parameter_conversion before handing it off to the source model. Automatically refers to the frequency_domain_source model via INFFT if no frequency_domain_source_model is given. Parameters ========== parameters: dict, optional - Parameters to evaluate the waveform for, this overwrites - `self.parameters`. - If not provided will fall back to `self.parameters`. + If not provided will use the last parameters used. Returns ======= @@ -166,11 +165,13 @@ def time_domain_strain(self, parameters=None): def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model, transformed_model_data_points, parameters): - if parameters is not None: - self.parameters = parameters + if parameters is None: + parameters = self._cache.get('parameters', None) + if parameters is None: + raise ValueError("No parameters given to generate waveform.") if ( self.use_cache - and self.parameters == self._cache['parameters'] + and parameters == self._cache.get('parameters', None) and self._cache['model'] == model and self._cache['transformed_model'] == transformed_model ): @@ -495,7 +496,9 @@ def frequency_domain_strain(self, parameters): from lalsimulation.gwsignal import GenerateFDWaveform if parameters is None: - parameters = self.parameters + parameters = self._cache.get("parameters", None) + if parameters is None: + raise ValueError("No parameters given to generate waveform.") hpc = _try_waveform_call( GenerateFDWaveform, @@ -529,7 +532,9 @@ def time_domain_strain(self, parameters): from lalsimulation.gwsignal import GenerateTDWaveform if parameters is None: - parameters = self.parameters + parameters = self._cache.get("parameters", None) + if parameters is None: + raise ValueError("No parameters given to generate waveform.") hpc = _try_waveform_call( GenerateTDWaveform, From 69abdfb69fc075c02c5ec335150e25a79d4597ca Mon Sep 17 
00:00:00 2001 From: Colm Talbot Date: Wed, 21 Jan 2026 17:03:15 -0500 Subject: [PATCH 043/110] BUG: fix core likelihood tests --- bilby/core/likelihood.py | 4 ++-- test/core/likelihood_test.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index 88947b5a2..eb350253f 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -398,9 +398,9 @@ def __init__(self, x, y, func, **kwargs): def log_likelihood(self, parameters=None): mu = self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) - if any(mu < 0.): - return -np.inf xp = array_module(mu) + if xp.any(mu < 0.): + return -np.inf return -xp.sum(xp.log(mu) + (self.y / mu)) def __repr__(self): diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index 8061f9e55..38c7c70b8 100644 --- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -403,7 +403,7 @@ def test_log_likelihood_dummy(self): poisson_likelihood = PoissonLikelihood( x=self.x, y=self.y, func=lambda x: np.linspace(1, 100, self.N) ) - with mock.patch("numpy.sum") as m: + with mock.patch("array_api_compat.numpy.sum") as m: m.return_value = 1 self.assertEqual(1, poisson_likelihood.log_likelihood()) @@ -495,7 +495,7 @@ def test_log_likelihood_default(self): exponential_likelihood = ExponentialLikelihood( x=self.x, y=self.y, func=lambda x: np.array([4.2]) ) - with mock.patch("numpy.sum") as m: + with mock.patch("array_api_compat.numpy.sum") as m: m.return_value = 3 self.assertEqual(-3, exponential_likelihood.log_likelihood()) From b3c38ba6af4de6fcc36dce4ea6585c100adc04cc Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 11:59:04 -0500 Subject: [PATCH 044/110] BUG: fix calibration calculations --- bilby/gw/detector/calibration.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/bilby/gw/detector/calibration.py b/bilby/gw/detector/calibration.py index 
229bdba25..28e9c784e 100644 --- a/bilby/gw/detector/calibration.py +++ b/bilby/gw/detector/calibration.py @@ -44,6 +44,7 @@ import numpy as np import pandas as pd +from array_api_compat import is_jax_namespace from scipy.interpolate import interp1d from ...compat.utils import array_module @@ -375,8 +376,9 @@ def get_calibration_factor(self, frequency_array, **params): delta_amplitude = self._evaluate_spline("amplitude", a, b, c, d, previous_nodes) delta_phase = self._evaluate_spline("phase", a, b, c, d, previous_nodes) calibration_factor = (1 + delta_amplitude) * (2 + 1j * delta_phase) / (2 - 1j * delta_phase) + xp = calibration_factor.__array_namespace__() - return calibration_factor + return xp.nan_to_num(calibration_factor) class Precomputed(Recalibrate): @@ -408,8 +410,21 @@ def get_calibration_factor(self, frequency_array, **params): idx = int(params.get(self.prefix, None)) if idx is None: raise KeyError(f"Calibration index for {self.label} not found.") - if not np.array_equal(frequency_array, self.frequency_array): - raise ValueError("Frequency grid passed to calibrator doesn't match.") + + xp = frequency_array.__array_namespace__() + if not xp.array_equal(frequency_array, self.frequency_array): + intersection, mask, _ = xp.intersect1d( + frequency_array, self.frequency_array, return_indices=True + ) + if len(intersection) != len(self.frequency_array): + raise ValueError("Frequency grid passed to calibrator doesn't match.") + output = xp.ones_like(frequency_array, dtype=complex) + curve = xp.asarray(self.curves[idx]) + if is_jax_namespace(xp): + output = output.at[mask].set(curve) + else: + output[mask] = curve + return output return self.curves[idx] @classmethod From 23479c87a1fcdb02a9251b14484d3cb7a9509a5b Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 11:59:27 -0500 Subject: [PATCH 045/110] EXAMPLE: update jax fast tutorial --- .../injection_examples/jax_fast_tutorial.py | 243 +++++++++--------- 1 file changed, 118 insertions(+), 125 
deletions(-) diff --git a/examples/gw_examples/injection_examples/jax_fast_tutorial.py b/examples/gw_examples/injection_examples/jax_fast_tutorial.py index 6bfd87ea8..b3eb23a55 100644 --- a/examples/gw_examples/injection_examples/jax_fast_tutorial.py +++ b/examples/gw_examples/injection_examples/jax_fast_tutorial.py @@ -10,96 +10,129 @@ We optionally use ripple waveforms and a JIT-compiled likelihood. """ import os -from itertools import product # Set OMP_NUM_THREADS to stop lalsimulation taking over my computer os.environ["OMP_NUM_THREADS"] = "1" import bilby -import bilby.gw.jaxstuff import numpy as np import jax import jax.numpy as jnp -from jax import random -from numpyro.infer import AIES, ESS # noqa -from numpyro.infer.ensemble_util import get_nondiagonal_indices +from bilby.compat.jax import JittedLikelihood +from ripple.waveforms import IMRPhenomPv2 jax.config.update("jax_enable_x64", True) -bilby.core.utils.setup_logger() # log_level="WARNING") - -def setup_prior(): - # Set up a PriorDict, which inherits from dict. - # By default we will sample all terms in the signal models. However, this will - # take a long time for the calculation, so for this example we will set almost - # all of the priors to be equall to their injected values. This implies the - # prior is a delta function at the true, injected value. In reality, the - # sampler implementation is smart enough to not sample any parameter that has - # a delta-function prior. - # The above list does *not* include mass_1, mass_2, theta_jn and luminosity - # distance, which means those are the parameters that will be included in the - # sampler. If we do nothing, then the default priors get used. 
- priors = bilby.gw.prior.BBHPriorDict() - del priors["mass_1"], priors["mass_2"] - priors["geocent_time"] = bilby.core.prior.Uniform(1126249642, 1126269642) - priors["luminosity_distance"].minimum = 1 - priors["luminosity_distance"].maximum = 500 - priors["chirp_mass"].minimum = 2.35 - priors["chirp_mass"].maximum = 2.45 - # priors["luminosity_distance"] = bilby.core.prior.PowerLaw(2.0, 10.0, 500.0) - # priors["sky_x"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["sky_y"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["sky_z"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["delta_phase"] = priors.pop("phase") - # del priors["tilt_1"], priors["tilt_2"], priors["phi_12"], priors["phi_jl"] - # priors["spin_1_x"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["spin_1_y"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["spin_1_z"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["spin_2_x"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["spin_2_y"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["spin_2_z"] = bilby.core.prior.Normal(mu=0, sigma=1) - # # del priors["a_1"], priors["a_2"] - # # priors["chi_1"] = bilby.core.prior.Uniform(-0.05, 0.05) - # # priors["chi_2"] = bilby.core.prior.Uniform(-0.05, 0.05) - # del priors["theta_jn"], priors["psi"], priors["delta_phase"] - # priors["orientation_w"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["orientation_x"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["orientation_y"] = bilby.core.prior.Normal(mu=0, sigma=1) - # priors["orientation_z"] = bilby.core.prior.Normal(mu=0, sigma=1) - return priors - - -def original_to_sampling_priors(priors, truth): - del priors["ra"], priors["dec"] - priors["zenith"] = bilby.core.prior.Cosine() - priors["azimuth"] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi) - priors["L1_time"] = bilby.core.prior.Uniform(truth["geocent_time"] - 0.1, truth["geocent_time"] + 0.1) +def bilby_to_ripple_spins( + theta_jn, + phi_jl, + tilt_1, + 
tilt_2, + phi_12, + a_1, + a_2, +): + """ + A simplified spherical to cartesian spin conversion function. + This is not equivalent to the method used in `bilby.gw.conversion` + which comes from `lalsimulation` and is not `JAX` compatible. + """ + iota = theta_jn + spin_1x = a_1 * jnp.sin(tilt_1) * jnp.cos(phi_jl) + spin_1y = a_1 * jnp.sin(tilt_1) * jnp.sin(phi_jl) + spin_1z = a_1 * jnp.cos(tilt_1) + spin_2x = a_2 * jnp.sin(tilt_2) * jnp.cos(phi_jl + phi_12) + spin_2y = a_2 * jnp.sin(tilt_2) * jnp.sin(phi_jl + phi_12) + spin_2z = a_2 * jnp.cos(tilt_2) + return iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z + + +def ripple_bbh( + frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase, + a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, **kwargs, +): + """ + Source function wrapper to ripple's IMRPhenomPv2 waveform generator. + This function cannot be jitted directly as the Bilby waveform generator + relies on inspecting the function signature. + + Parameters + ---------- + frequency: jnp.ndarray + Frequencies at which to compute the waveform. + mass_1: float | jnp.ndarray + Mass of the primary component in solar masses. + mass_2: float | jnp.ndarray + Mass of the secondary component in solar masses. + luminosity_distance: float | jnp.ndarray + Luminosity distance to the source in Mpc. + theta_jn: float | jnp.ndarray + Angle between total angular momentum and line of sight in radians. + phase: float | jnp.ndarray + Phase at coalescence in radians. + a_1: float | jnp.ndarray + Dimensionless spin magnitude of the primary component. + a_2: float | jnp.ndarray + Dimensionless spin magnitude of the secondary component. + tilt_1: float | jnp.ndarray + Tilt angle of the primary component spin in radians. + tilt_2: float | jnp.ndarray + Tilt angle of the secondary component spin in radians. + phi_12: float | jnp.ndarray + Azimuthal angle between the two spin vectors in radians. 
+ phi_jl: float | jnp.ndarray + Azimuthal angle of the total angular momentum vector in radians. + **kwargs + Additional keyword arguments. Must include 'minimum_frequency'. + + Returns + ------- + dict + Dictionary containing the plus and cross polarizations of the waveform. + """ + iota, *cartesian_spins = bilby_to_ripple_spins( + # iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_ripple_spins( + theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2 + ) + frequencies = jnp.maximum(frequency, kwargs["minimum_frequency"]) + theta = jnp.array([ + mass_1, mass_2, *cartesian_spins, + luminosity_distance, jnp.array(0.0), phase, iota + ]) + wf_func = jax.jit(IMRPhenomPv2.gen_IMRPhenomPv2) + hp, hc = wf_func(frequencies, theta, jnp.array(20.0)) + return dict(plus=hp, cross=hc) -def main(use_jax, model, idx): +def main(): # Set the duration and sampling frequency of the data segment that we're # going to inject the signal into duration = 64.0 sampling_frequency = 2048.0 minimum_frequency = 20.0 - if use_jax: - duration = jax.numpy.array(duration) - sampling_frequency = jax.numpy.array(sampling_frequency) - minimum_frequency = jax.numpy.array(minimum_frequency) + duration = jnp.array(duration) + sampling_frequency = jnp.array(sampling_frequency) + minimum_frequency = jnp.array(minimum_frequency) # Specify the output directory and the name of the simulation. - outdir = "pp-test-2" - label = f"{model}_{'jax' if use_jax else 'numpy'}_{idx}" + outdir = "outdir" + label = f"jax_fast_tutorial" # Set up a random seed for result reproducibility. This is optional! 
- bilby.core.utils.random.seed(88170235 + idx * 1000) + bilby.core.utils.random.seed(88170235) - priors = setup_prior() + priors = bilby.gw.prior.BBHPriorDict() injection_parameters = priors.sample() - if model == "relbin": - injection_parameters["fiducial"] = 1 - original_to_sampling_priors(priors, injection_parameters) + injection_parameters["geocent_time"] = 1000000000.0 + injection_parameters["luminosity_distance"] = 400.0 + del priors["ra"], priors["dec"] + priors["zenith"] = bilby.core.prior.Cosine() + priors["azimuth"] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi) + priors["L1_time"] = bilby.core.prior.Uniform( + injection_parameters["geocent_time"] - 0.1, + injection_parameters["geocent_time"] + 0.1, + ) # Fixed arguments passed into the source model waveform_arguments = dict( @@ -108,28 +141,13 @@ def main(use_jax, model, idx): minimum_frequency=minimum_frequency, ) - if use_jax: - match model: - case "relbin": - fdsm = bilby.gw.jaxstuff.ripple_bbh_relbin - case _: - fdsm = bilby.gw.jaxstuff.ripple_bbh - else: - match model: - case "relbin": - fdsm = bilby.gw.source.lal_binary_black_hole_relative_binning - case _: - fdsm = bilby.gw.source.lal_binary_black_hole - # fdsm = bilby.gw.source.sinegaussian - # Create the waveform_generator using a LAL BinaryBlackHole source function waveform_generator = bilby.gw.WaveformGenerator( duration=duration, sampling_frequency=sampling_frequency, - frequency_domain_source_model=fdsm, - # parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters, + frequency_domain_source_model=ripple_bbh, waveform_arguments=waveform_arguments, - use_cache=not use_jax, + use_cache=False, ) # Set up interferometers. 
In this case we'll use two interferometers @@ -145,30 +163,11 @@ def main(use_jax, model, idx): waveform_generator=waveform_generator, parameters=injection_parameters, raise_error=False, ) - if use_jax: - ifos.set_array_backend(jax.numpy) - - if model == "mb": - if use_jax: - pass - else: - waveform_generator.frequency_domain_source_model = ( - bilby.gw.source.binary_black_hole_frequency_sequence - ) - del waveform_generator.waveform_arguments["minimum_frequency"] + ifos.set_array_backend(jnp) # Initialise the likelihood by passing in the interferometer data (ifos) and # the waveform generator - match model: - case "relbin": - likelihood_class = ( - bilby.gw.likelihood.RelativeBinningGravitationalWaveTransient - ) - case "mb": - likelihood_class = bilby.gw.likelihood.MBGravitationalWaveTransient - case _: - likelihood_class = bilby.gw.likelihood.GravitationalWaveTransient - likelihood = likelihood_class( + likelihood = bilby.gw.likelihood.GravitationalWaveTransient( interferometers=ifos, waveform_generator=waveform_generator, priors=priors, @@ -176,43 +175,37 @@ def main(use_jax, model, idx): distance_marginalization=True, reference_frame=ifos, time_reference="L1", - # epsilon=0.1, - # update_fiducial_parameters=True, ) + # Do an initial likelihood evaluation to trigger any internal setup + likelihood.log_likelihood_ratio(priors.sample()) + # Wrap the likelihood with the JittedLikelihood to JIT compile the likelihood + # evaluation + likelihood = JittedLikelihood(likelihood) + # Evaluate the likelihood once to trigger the JIT compilation, this will take + # a few seconds as compiling the waveform takes some time + likelihood.log_likelihood_ratio(priors.sample()) # use the log_compiles context so we can make sure there aren't recompilations # inside the sampling loop - if True: - # with jax.log_compiles(): + with jax.log_compiles(): result = bilby.run_sampler( likelihood=likelihood, priors=priors, - sampler="jaxted" if use_jax else "dynesty", - nlive=1000, + 
sampler="dynesty", + nlive=100, sample="acceptance-walk", - method="nest", - nsteps=100, - naccept=30, + naccept=5, injection_parameters=injection_parameters, outdir=outdir, label=label, - npool=None if use_jax else 16, - # save="hdf5", - save=False, + npool=None, + save="hdf5", rseed=np.random.randint(0, 100000), ) # Make a corner plot. - # result.plot_corner() - import IPython; IPython.embed() - return result.sampling_time + result.plot_corner() if __name__ == "__main__": - times = dict() - # for arg in product([True, False][:], ["relbin", "mb", "regular"][2:3]): - # times[arg] = main(*arg) - with jax.log_compiles(): - for idx in np.arange(100): - times[idx] = main(True, "mb", idx) - print(times) + main() From 17114a256bf545423add2085342594146f2114ae Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 11:59:45 -0500 Subject: [PATCH 046/110] TST: refactor marginalization tests to be less restrictive --- test/gw/likelihood/marginalization_test.py | 59 +++++++++++++--------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/test/gw/likelihood/marginalization_test.py b/test/gw/likelihood/marginalization_test.py index 3538a4c58..176e83453 100644 --- a/test/gw/likelihood/marginalization_test.py +++ b/test/gw/likelihood/marginalization_test.py @@ -3,6 +3,7 @@ import pytest import unittest from copy import deepcopy +from functools import cached_property from itertools import product from parameterized import parameterized @@ -230,54 +231,63 @@ def setUp(self): maximum=self.parameters["geocent_time"] + 0.1 ) - trial_roq_paths = [ - "/roq_basis", - os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), - "/home/cbc/ROQ_data/IMRPhenomPv2/4s", - ] - roq_dir = None - for path in trial_roq_paths: - if os.path.isdir(path): - roq_dir = path - break - if roq_dir is None: - raise Exception("Unable to load ROQ basis: cannot proceed with tests") - - self.roq_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + 
self.relbin_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning, start_time=1126259640, waveform_arguments=dict( reference_frequency=20.0, + minimum_frequency=20.0, waveform_approximant="IMRPhenomPv2", - frequency_nodes_linear=np.load(f"{roq_dir}/fnodes_linear.npy"), - frequency_nodes_quadratic=np.load(f"{roq_dir}/fnodes_quadratic.npy"), ) ) - self.roq_linear_matrix_file = f"{roq_dir}/B_linear.npy" - self.roq_quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy" - self.relbin_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + self.multiband_waveform_generator = bilby.gw.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, start_time=1126259640, waveform_arguments=dict( reference_frequency=20.0, - minimum_frequency=20.0, waveform_approximant="IMRPhenomPv2", ) ) - self.multiband_waveform_generator = bilby.gw.WaveformGenerator( + @property + def roq_dir(self): + trial_roq_paths = [ + "/roq_basis", + os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), + "/home/cbc/ROQ_data/IMRPhenomPv2/4s", + ] + if "BILBY_TESTING_ROQ_DIR" in os.environ: + trial_roq_paths.insert(0, os.environ["BILBY_TESTING_ROQ_DIR"]) + for path in trial_roq_paths: + if os.path.isdir(path): + return path + raise Exception("Unable to load ROQ basis: cannot proceed with tests") + + @property + def roq_linear_matrix_file(self): + return f"{self.roq_dir}/B_linear.npy" + + @property + def roq_quadratic_matrix_file(self): + return f"{self.roq_dir}/B_quadratic.npy" + + @cached_property + def roq_waveform_generator(self): + return 
bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, start_time=1126259640, waveform_arguments=dict( reference_frequency=20.0, waveform_approximant="IMRPhenomPv2", + frequency_nodes_linear=np.load(f"{self.roq_dir}/fnodes_linear.npy"), + frequency_nodes_quadratic=np.load(f"{self.roq_dir}/fnodes_quadratic.npy"), ) ) @@ -287,7 +297,6 @@ def tearDown(self): del self.parameters del self.interferometers del self.waveform_generator - del self.roq_waveform_generator del self.priors @classmethod From a98fb4342aa7f497b137b215f8c664157e55e06f Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 13:31:28 -0500 Subject: [PATCH 047/110] DOC: update jittedlikelihood docstring --- bilby/compat/jax.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/bilby/compat/jax.py b/bilby/compat/jax.py index f94f64cd8..af0699147 100644 --- a/bilby/compat/jax.py +++ b/bilby/compat/jax.py @@ -7,11 +7,6 @@ class JittedLikelihood(Likelihood): """ A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`. - .. note:: - - This is currently hardcoded to return the log likelihood ratio, regardless of - the input. 
- Parameters ========== likelihood: bilby.core.likelihood.Likelihood From f47693008ca64ab00f1f330a1564892cc9013ba4 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 14:34:39 -0500 Subject: [PATCH 048/110] TEST: speed up initializing prior tests --- test/core/prior/prior_test.py | 44 +++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 43ce08fe5..2b0500db8 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -6,9 +6,28 @@ from scipy.integrate import trapezoid +aligned_prior_complex = bilby.gw.prior.AlignedSpin( + a_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0), + z_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0, minimum=-1), + name="test", + unit="unit", + num_interp=1000, +) + +hp_map_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "prior_files/GW150914_testing_skymap.fits", +) +hp_dist = bilby.gw.prior.HealPixMapPriorDist( + hp_map_file, names=["testra", "testdec"] +) +hp_3d_dist = bilby.gw.prior.HealPixMapPriorDist( + hp_map_file, names=["testra", "testdec", "testdistance"], distance=True +) + + class TestPriorClasses(unittest.TestCase): def setUp(self): - # set multivariate Gaussian mvg = bilby.core.prior.MultivariateGaussianDist( names=["testa", "testb"], @@ -22,16 +41,6 @@ def setUp(self): covs=np.array([[2.0, 0.5], [0.5, 2.0]]), weights=1.0, ) - hp_map_file = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "prior_files/GW150914_testing_skymap.fits", - ) - hp_dist = bilby.gw.prior.HealPixMapPriorDist( - hp_map_file, names=["testra", "testdec"] - ) - hp_3d_dist = bilby.gw.prior.HealPixMapPriorDist( - hp_map_file, names=["testra", "testdec", "testdistance"], distance=True - ) def condition_func(reference_params, test_param): return reference_params.copy() @@ -102,13 +111,7 @@ def condition_func(reference_params, test_param): name="test", unit="unit", minimum=1e-2, 
maximum=1e2 ), bilby.gw.prior.AlignedSpin(name="test", unit="unit"), - bilby.gw.prior.AlignedSpin( - a_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0), - z_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0, minimum=-1), - name="test", - unit="unit", - num_interp=1000, - ), + aligned_prior_complex, bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit"), bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit"), bilby.core.prior.MultivariateNormal(dist=mvn, name="testa", unit="unit"), @@ -861,15 +864,16 @@ def test_jax_methods(self): for prior in self.priors: if bilby.core.prior.JointPrior in prior.__class__.__mro__: continue - print(prior) scaled = prior.rescale(points) assert isinstance(scaled, jax.Array) if isinstance(prior, bilby.core.prior.DeltaFunction): continue - assert max(abs(prior.cdf(scaled) - points)) < 1e-6 probs = prior.prob(scaled) assert min(probs) > 0 assert max(abs(jax.numpy.log(probs) - prior.ln_prob(scaled))) < 1e-6 + if isinstance(prior, bilby.core.prior.WeightedDiscreteValues): + continue + assert max(abs(prior.cdf(scaled) - points)) < 1e-6 if __name__ == "__main__": From e4c96c3a5b379fa757e6c91b3493588958c09121 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 14:34:52 -0500 Subject: [PATCH 049/110] BUG: fix some test failures --- bilby/core/prior/analytical.py | 57 ++++++++++++++++---------------- bilby/core/prior/interpolated.py | 6 ++-- bilby/gw/prior.py | 2 +- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 76262dcf2..ca48e3393 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -13,6 +13,7 @@ stdtr, stdtrit, xlogy, + xlog1py, ) from .base import Prior @@ -369,7 +370,7 @@ def ln_prob(self, val, *, xp=np): @xp_wrap def cdf(self, val, *, xp=np): asymmetric = xp.log(abs(val) / self.minimum) / xp.log(self.maximum / self.minimum) - return 0.5 * (1 + xp.sign(val) * asymmetric) + 
return xp.clip(0.5 * (1 + xp.sign(val) * asymmetric), 0, 1) class Cosine(Prior): @@ -989,7 +990,7 @@ def prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - return xp.exp(self.ln_prob(val)) + return xp.exp(self.ln_prob(val, xp=xp)) @xp_wrap def ln_prob(self, val, *, xp=np): @@ -1003,13 +1004,9 @@ def ln_prob(self, val, *, xp=np): ======= Union[float, array_like]: Prior probability of val """ - _ln_prob = ( - xlogy(xp.asarray(self.alpha - 1), val - self.minimum) - + xlogy(xp.asarray(self.beta - 1), self.maximum - val) - - betaln(self.alpha, self.beta) - - xlogy(self.alpha + self.beta - 1, self.maximum - self.minimum) - ) - return xp.nan_to_num(_ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) + ln_prob = xlog1py(self.beta - 1.0, -val) + xlogy(self.alpha - 1.0, val) + ln_prob -= betaln(self.alpha, self.beta) + return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) @xp_wrap def cdf(self, val, *, xp=np): @@ -1442,20 +1439,21 @@ def __init__( The unit of the parameter. Used for plotting. 
""" + xp = array_module(values) nvalues = len(values) - values = np.array(values) + values = xp.array(values) if values.shape != (nvalues,): raise ValueError( f"Shape of argument 'values' must be 1d array-like but has shape {values.shape}" ) - minimum = np.min(values) + minimum = xp.min(values) # Small delta added to help with MCMC walking - maximum = np.max(values) * (1 + 1e-15) + maximum = xp.max(values) * (1 + 1e-15) super(WeightedDiscreteValues, self).__init__( name=name, latex_label=latex_label, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary) self.nvalues = nvalues - sorter = np.argsort(values) + sorter = xp.argsort(values) self._values_array = values[sorter] # inititialization of priors from repr only supports @@ -1463,9 +1461,9 @@ def __init__( self.values = self._values_array.tolist() weights = ( - np.array(weights) / np.sum(weights) + xp.array(weights) / xp.sum(weights) if weights is not None - else np.ones(self.nvalues) / self.nvalues + else xp.ones(self.nvalues) / self.nvalues ) # check for consistent shape of input if weights.shape != (self.nvalues,): @@ -1476,14 +1474,15 @@ def __init__( ) self._weights_array = weights[sorter] self.weights = self._weights_array.tolist() - self._lnweights_array = np.log(self._weights_array) + self._lnweights_array = xp.log(self._weights_array) # save cdf for rescaling - _cumulative_weights_array = np.cumsum(self._weights_array) + _cumulative_weights_array = xp.cumsum(self._weights_array) # insert 0 for values smaller than minimum - self._cumulative_weights_array = np.insert(_cumulative_weights_array, 0, 0) + self._cumulative_weights_array = xp.insert(_cumulative_weights_array, 0, 0) - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=np): """ 'Rescale' a sample from the unit line element to the discrete-value prior. 
@@ -1498,10 +1497,11 @@ def rescale(self, val): ======= Union[float, array_like]: Rescaled probability """ - index = np.searchsorted(self._cumulative_weights_array[1:], val) - return self._values_array[index] + index = xp.searchsorted(self._cumulative_weights_array[1:], val) + return xp.asarray(self._values_array)[index] - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=np): """Return the cumulative prior probability of val. Parameters @@ -1512,10 +1512,11 @@ def cdf(self, val): ======= float: cumulative prior probability of val """ - index = np.searchsorted(self._values_array, val, side="right") - return self._cumulative_weights_array[index] + index = xp.searchsorted(self._values_array, val, side="right") + return xp.asarray(self._cumulative_weights_array)[index] - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=np): """Return the prior probability of val. Parameters @@ -1526,9 +1527,9 @@ def prob(self, val): ======= float: Prior probability of val """ - index = np.searchsorted(self._values_array, val) - index = np.clip(index, 0, self.nvalues - 1) - p = np.where(self._values_array[index] == val, self._weights_array[index], 0) + index = xp.searchsorted(self._values_array, val) + index = xp.clip(index, 0, self.nvalues - 1) + p = xp.where(self._values_array[index] == val, self._weights_array[index], 0) # turn 0d numpy array to scalar return p[()] diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index 57e04738d..ab03809d1 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -77,10 +77,10 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return self.probability_density(val) + return self.probability_density(val)[()] def cdf(self, val): - return self.cumulative_distribution(val) + return self.cumulative_distribution(val)[()] @xp_wrap def rescale(self, val, *, xp=np): @@ -89,7 +89,7 @@ def rescale(self, val, *, xp=np): This maps to the inverse 
CDF. This is done using interpolation. """ - return self.inverse_cumulative_distribution(val) + return self.inverse_cumulative_distribution(val)[()] @property def minimum(self): diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 04c1d0db8..3258d37d1 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -512,7 +512,7 @@ def integrand(aa, chi): after performing the integral over spin orientation using a delta function identity. """ - return a_prior.prob(aa) * z_prior.prob(chi / aa) / aa + return a_prior.prob(aa, xp=np) * z_prior.prob(chi / aa, xp=np) / aa self.num_interp = 10_000 if num_interp is None else num_interp xx = np.linspace(chi_min, chi_max, self.num_interp) From a85838f44ccdec4bf2792a35728e9a81458a9a4d Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 14:44:43 -0500 Subject: [PATCH 050/110] BUG: fix conditional+joint prior rescaling --- bilby/core/prior/dict.py | 8 +++++++- test/core/prior/conditional_test.py | 8 +++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index b732c7082..eb02ae841 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -888,7 +888,13 @@ def rescale(self, keys, theta): result[key] = self[key].rescale( theta[index], **self.get_required_variables(key) ) - self[key].least_recently_sampled = result[key] + if isinstance(self[key], JointPrior) and result[key] is not None: + for key, val in zip(self[key].dist.names, result[key]): + self[key].least_recently_sampled = val + result[key] = val + else: + self[key].least_recently_sampled = result[key] + return xp.array([result[key] for key in keys]) def _update_rescale_keys(self, keys): diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 68db12ed7..e7e5ec670 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -349,8 +349,10 @@ def test_rescale_with_joint_prior(self): ) ) - ref_variables = 
list(self.test_sample.values()) + [0.4, 0.1] - keys = list(self.test_sample.keys()) + names + ref_variables = list(self.test_sample.values()) + ref_variables = ref_variables[:2] + [0.1] + ref_variables[2:] + [0.4] + keys = list(self.test_sample.keys()) + keys = keys[:2] + ["mvgvar_0"] + keys[2:] + ["mvgvar_1"] res = priordict.rescale(keys=keys, theta=ref_variables) self.assertEqual(np.shape(res), (6,)) @@ -360,7 +362,7 @@ def test_rescale_with_joint_prior(self): expected = [self.test_sample["var_0"]] for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) - np.testing.assert_array_equal(expected, res[:4]) + np.testing.assert_array_equal(expected, list(res)[:2] + list(res)[3:5]) def test_cdf(self): """ From 93cda56a3fc57d08dab581d6178ce5dfa65e2cb2 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 14:57:03 -0500 Subject: [PATCH 051/110] BUG: fix some gnarly conversion corner cases --- bilby/compat/utils.py | 11 ++++++++++- bilby/gw/conversion.py | 4 ++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py index 98fae293d..057f6eb1c 100644 --- a/bilby/compat/utils.py +++ b/bilby/compat/utils.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable + import numpy as np from array_api_compat import array_namespace @@ -10,7 +12,14 @@ def array_module(arr): try: return array_namespace(arr) except TypeError: - if arr.__class__.__module__ == "builtins": + if isinstance(arr, dict): + try: + return array_namespace(*[val for val in arr.values() if not isinstance(val, str)]) + except TypeError: + return np + elif arr.__class__.__module__ == "builtins" and isinstance(arr, Iterable): + return array_namespace(arr) + elif arr.__class__.__module__ == "builtins": return np elif arr.__module__.startswith("pandas"): return np diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 95687e7fb..5085ee343 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -241,7 
+241,7 @@ def convert_to_lal_binary_black_hole_parameters(parameters): """ converted_parameters = parameters.copy() original_keys = list(converted_parameters.keys()) - xp = array_module(parameters.values()) + xp = array_module(parameters) if 'luminosity_distance' not in original_keys: if 'redshift' in converted_parameters.keys(): converted_parameters['luminosity_distance'] = \ @@ -2119,7 +2119,7 @@ def generate_spin_parameters(sample): output_sample = sample.copy() output_sample = generate_component_spins(output_sample) - xp = array_module(sample["spin_1z"]) + xp = array_module(sample) output_sample['chi_eff'] = (output_sample['spin_1z'] + output_sample['spin_2z'] * From 0316acc4649fe13db9c724e583ec5fb86b4244f3 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 15:28:47 -0500 Subject: [PATCH 052/110] BUG: fix multiband likelihood --- bilby/gw/detector/calibration.py | 5 +++-- bilby/gw/likelihood/multiband.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bilby/gw/detector/calibration.py b/bilby/gw/detector/calibration.py index 28e9c784e..6f7390bde 100644 --- a/bilby/gw/detector/calibration.py +++ b/bilby/gw/detector/calibration.py @@ -362,8 +362,9 @@ def get_calibration_factor(self, frequency_array, **params): calibration_factor : array-like The factor to multiply the strain by. 
""" - log10f_per_deltalog10f = ( - np.log10(frequency_array) - self.log_spline_points[0] + log10f_per_deltalog10f = np.nan_to_num( + np.log10(frequency_array) - self.log_spline_points[0], + neginf=0.0, ) / self.delta_log_spline_points previous_nodes = np.clip(np.floor(log10f_per_deltalog10f).astype(int), a_min=0, a_max=self.n_points - 2) b = log10f_per_deltalog10f - previous_nodes diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index 1ece29154..d4d195e24 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -769,7 +769,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr if self.linear_interpolation: optimal_snr_squared = xp.vdot( - (strain * strain.conjugate()).real, + xp.abs(strain)**2, self.quadratic_coeffs[interferometer.name] ) else: @@ -780,7 +780,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr Mb = self.Mbs[b] if b == 0: optimal_snr_squared += (4. / self.interferometers.duration) * xp.vdot( - (strain[start_idx:end_idx + 1] * strain[start_idx:end_idx + 1].conjugate()).real, + xp.abs(strain[start_idx:end_idx + 1])**2, interferometer.frequency_mask[Ks:Ke + 1] * self.windows[start_idx:end_idx + 1] / interferometer.power_spectral_density_array[Ks:Ke + 1]) else: @@ -790,7 +790,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr self.hbcs[interferometer.name][b][-Mb:] = xp.fft.irfft(self.wths[interferometer.name][b]) thbc = xp.fft.rfft(self.hbcs[interferometer.name][b]) optimal_snr_squared += (4. 
/ self.Tbhats[b]) * xp.vdot( - thbc * np.conjugate(thbc).real, self.Ibcs[interferometer.name][b]) + xp.abs(thbc)**2, self.Ibcs[interferometer.name][b]) complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) From 2c3f8fbde8aa36cb3704375375da5a36b5ce82d5 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 15:48:43 -0500 Subject: [PATCH 053/110] BUG: fix bug in array_namespace check --- bilby/compat/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py index 057f6eb1c..4b099969b 100644 --- a/bilby/compat/utils.py +++ b/bilby/compat/utils.py @@ -18,7 +18,10 @@ def array_module(arr): except TypeError: return np elif arr.__class__.__module__ == "builtins" and isinstance(arr, Iterable): - return array_namespace(arr) + try: + return array_namespace(*arr) + except TypeError: + return np elif arr.__class__.__module__ == "builtins": return np elif arr.__module__.startswith("pandas"): From 5d10a8a2eaf17fd52469ce0babcdfc573fe69908 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 15:57:44 -0500 Subject: [PATCH 054/110] TEST: make sure healpix prior doesn't store state between calls --- test/core/prior/prior_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 2b0500db8..bed42cf19 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -42,6 +42,10 @@ def setUp(self): weights=1.0, ) + # need to reset this for the repr test to get equality correct + hp_dist.requested_parameters = {"testra": None, "testdec": None} + hp_3d_dist.requested_parameters = {"testra": None, "testdec": None, "testdistance": None} + def condition_func(reference_params, test_param): return reference_params.copy() From 27f20466a57b1514595716e15028cbd835d987b8 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 16:03:15 -0500 Subject: [PATCH 055/110] FMT: example formatting fixes --- 
.../injection_examples/jax_fast_tutorial.py | 38 ++++++++++++++----- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/examples/gw_examples/injection_examples/jax_fast_tutorial.py b/examples/gw_examples/injection_examples/jax_fast_tutorial.py index b3eb23a55..6f6654536 100644 --- a/examples/gw_examples/injection_examples/jax_fast_tutorial.py +++ b/examples/gw_examples/injection_examples/jax_fast_tutorial.py @@ -15,9 +15,9 @@ os.environ["OMP_NUM_THREADS"] = "1" import bilby -import numpy as np import jax import jax.numpy as jnp +import numpy as np from bilby.compat.jax import JittedLikelihood from ripple.waveforms import IMRPhenomPv2 @@ -49,8 +49,19 @@ def bilby_to_ripple_spins( def ripple_bbh( - frequency, mass_1, mass_2, luminosity_distance, theta_jn, phase, - a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl, **kwargs, + frequency, + mass_1, + mass_2, + luminosity_distance, + theta_jn, + phase, + a_1, + a_2, + tilt_1, + tilt_2, + phi_12, + phi_jl, + **kwargs, ): """ Source function wrapper to ripple's IMRPhenomPv2 waveform generator. @@ -85,21 +96,27 @@ def ripple_bbh( Azimuthal angle of the total angular momentum vector in radians. **kwargs Additional keyword arguments. Must include 'minimum_frequency'. - + Returns ------- dict Dictionary containing the plus and cross polarizations of the waveform. 
""" iota, *cartesian_spins = bilby_to_ripple_spins( - # iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_ripple_spins( theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2 ) frequencies = jnp.maximum(frequency, kwargs["minimum_frequency"]) - theta = jnp.array([ - mass_1, mass_2, *cartesian_spins, - luminosity_distance, jnp.array(0.0), phase, iota - ]) + theta = jnp.array( + [ + mass_1, + mass_2, + *cartesian_spins, + luminosity_distance, + jnp.array(0.0), + phase, + iota, + ] + ) wf_func = jax.jit(IMRPhenomPv2.gen_IMRPhenomPv2) hp, hc = wf_func(frequencies, theta, jnp.array(20.0)) return dict(plus=hp, cross=hc) @@ -160,7 +177,8 @@ def main(): start_time=injection_parameters["geocent_time"] - duration + 2, ) ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters, + waveform_generator=waveform_generator, + parameters=injection_parameters, raise_error=False, ) ifos.set_array_backend(jnp) From 91ee508a5e28792f4905616bf482110095d378f9 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 22 Jan 2026 16:33:32 -0500 Subject: [PATCH 056/110] BUG: make sure indices don't overflow in roq --- bilby/gw/likelihood/roq.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 2ccffe1c9..8fdc880c4 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -487,6 +487,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr indices, in_bounds = self._closest_time_indices( ifo_time, self.weights['time_samples']) + indices = xp.clip(indices, 0, len(self.weights['time_samples']) - 1) d_inner_h_tc_array = xp.einsum( 'i,ji->j', xp.conjugate(h_linear), From 5ddf3e334e62c6a09fc20fc3c6f23a1a77cce44a Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 23 Jan 2026 08:48:10 -0500 Subject: [PATCH 057/110] BUG: fix multiband time marginalization setup --- bilby/gw/likelihood/multiband.py | 12 ++++++------ 1 file changed, 6 
insertions(+), 6 deletions(-) diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index d4d195e24..d95a5e569 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -711,20 +711,20 @@ def setup_multibanding_from_weights(self, weights): setattr(self, key, value) def _setup_time_marginalization_multiband(self): - """This overwrites attributes set by _setup_time_marginalization of the base likelihood class""" + self._beam_pattern_reference_time = ( + self.priors['geocent_time'].minimum + self.priors['geocent_time'].maximum + ) / 2 N = self.Nbs[-1] // 2 self._delta_tc = self.durations[0] / N - self._times = \ - self.interferometers.start_time + np.arange(N) * self._delta_tc + self._times = ( + np.arange(N) * self._delta_tc + (self._beam_pattern_reference_time - self.interferometers.start_time) + ) % self.interferometers.duration + self.interferometers.start_time self.time_prior_array = \ self.priors['geocent_time'].prob(self._times) * self._delta_tc # allocate array which is FFTed at each likelihood evaluation self._full_d_h = np.zeros(N, dtype=complex) # idxs to convert full frequency points to banded frequency points, used for filling _full_d_h. 
self._full_to_multiband = [int(f * self.durations[0]) for f in self.banded_frequency_points] - self._beam_pattern_reference_time = ( - self.priors['geocent_time'].minimum + self.priors['geocent_time'].maximum - ) / 2 def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True, parameters=None): """ From 79ae333da91a26c1a510ca4335a17ec854bb19ad Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 23 Jan 2026 08:48:32 -0500 Subject: [PATCH 058/110] BUG: fix roq interpolation for out of bounds sample --- bilby/gw/likelihood/roq.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 8fdc880c4..8f81f950b 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -567,9 +567,10 @@ def _interp_five_samples(time_samples, values, time): value: float The value of the function at the input time """ + xp = time_samples.__array_namespace__() r1 = (-values[0] + 8. * values[1] - 14. * values[2] + 8. * values[3] - values[4]) / 4. r2 = values[2] - 2. * values[3] + values[4] - a = (time_samples[3] - time) / (time_samples[1] - time_samples[0]) + a = (time_samples[3] - time) / xp.maximum(time_samples[1] - time_samples[0], 1e-12) b = 1. - a c = (a**3. - a) / 6. d = (b**3. - b) / 6. From 63b6f30da575b74d18ac63aaaac0ebac1fab116c Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 23 Jan 2026 08:48:47 -0500 Subject: [PATCH 059/110] TYPO: fix typo in jax example --- examples/gw_examples/injection_examples/jax_fast_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gw_examples/injection_examples/jax_fast_tutorial.py b/examples/gw_examples/injection_examples/jax_fast_tutorial.py index 6f6654536..56b1b4d3a 100644 --- a/examples/gw_examples/injection_examples/jax_fast_tutorial.py +++ b/examples/gw_examples/injection_examples/jax_fast_tutorial.py @@ -134,7 +134,7 @@ def main(): # Specify the output directory and the name of the simulation. 
outdir = "outdir" - label = f"jax_fast_tutorial" + label = "jax_fast_tutorial" # Set up a random seed for result reproducibility. This is optional! bilby.core.utils.random.seed(88170235) From 23a3d7929e267aaf9b973bc109e9a3fa387d080a Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 23 Jan 2026 08:49:22 -0500 Subject: [PATCH 060/110] REFACTOR: refactor more roq likelihood tests --- test/gw/likelihood_test.py | 126 ++++++++++++++++++++----------------- 1 file changed, 68 insertions(+), 58 deletions(-) diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 4ee18e6a4..1e5351056 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -1,6 +1,7 @@ import os import unittest import tempfile +from functools import cached_property from itertools import product from parameterized import parameterized import pytest @@ -289,31 +290,6 @@ def setUp(self): self.duration = 4 self.sampling_frequency = 2048 - # Possible locations for the ROQ: in the docker image, local, or on CIT - trial_roq_paths = [ - "/roq_basis", - os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), - "/home/cbc/ROQ_data/IMRPhenomPv2/4s", - ] - roq_dir = None - for path in trial_roq_paths: - if os.path.isdir(path): - roq_dir = path - break - if roq_dir is None: - raise Exception("Unable to load ROQ basis: cannot proceed with tests") - - linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - - fnodes_linear_file = "{}/fnodes_linear.npy".format(roq_dir) - fnodes_linear = np.load(fnodes_linear_file).T - fnodes_quadratic_file = "{}/fnodes_quadratic.npy".format(roq_dir) - fnodes_quadratic = np.load(fnodes_quadratic_file).T - self.linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - self.quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - self.params_file = "{}/params.dat".format(roq_dir) - self.test_parameters = dict( mass_1=36.0, mass_2=36.0, @@ -362,20 +338,6 @@ def setUp(self): 
self.ifos = ifos - roq_wfg = bilby.gw.waveform_generator.WaveformGenerator( - duration=self.duration, - sampling_frequency=self.sampling_frequency, - frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, - waveform_arguments=dict( - frequency_nodes_linear=fnodes_linear, - frequency_nodes_quadratic=fnodes_quadratic, - reference_frequency=20.0, - waveform_approximant="IMRPhenomPv2", - ), - ) - - self.roq_wfg = roq_wfg - self.non_roq = bilby.gw.likelihood.GravitationalWaveTransient( interferometers=ifos, waveform_generator=non_roq_wfg ) @@ -387,33 +349,81 @@ def setUp(self): priors=self.priors.copy(), ) - self.roq = bilby.gw.likelihood.ROQGravitationalWaveTransient( - interferometers=ifos, - waveform_generator=roq_wfg, - linear_matrix=linear_matrix_file, - quadratic_matrix=quadratic_matrix_file, - priors=self.priors, - ) - - self.roq_phase = bilby.gw.likelihood.ROQGravitationalWaveTransient( - interferometers=ifos, - waveform_generator=roq_wfg, - linear_matrix=linear_matrix_file, - quadratic_matrix=quadratic_matrix_file, - phase_marginalization=True, - priors=self.priors.copy(), - ) - def tearDown(self): del ( - self.roq, self.non_roq, self.non_roq_phase, - self.roq_phase, self.ifos, self.priors, ) + @property + def roq_dir(self): + trial_roq_paths = [ + "/roq_basis", + os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), + "/home/cbc/ROQ_data/IMRPhenomPv2/4s", + ] + if "BILBY_TESTING_ROQ_DIR" in os.environ: + trial_roq_paths.insert(0, os.environ["BILBY_TESTING_ROQ_DIR"]) + print(trial_roq_paths) + for path in trial_roq_paths: + print(path, os.path.isdir(path)) + if os.path.isdir(path): + return path + raise Exception("Unable to load ROQ basis: cannot proceed with tests") + + @property + def linear_matrix_file(self): + return f"{self.roq_dir}/B_linear.npy" + + @property + def quadratic_matrix_file(self): + return f"{self.roq_dir}/B_quadratic.npy" + + @property + def params_file(self): + return f"{self.roq_dir}/params.dat" + + 
@cached_property + def roq_wfg(self): + fnodes_linear_file = f"{self.roq_dir}/fnodes_linear.npy" + fnodes_quadratic_file = f"{self.roq_dir}/fnodes_quadratic.npy" + fnodes_linear = np.load(fnodes_linear_file).T + fnodes_quadratic = np.load(fnodes_quadratic_file).T + return bilby.gw.waveform_generator.WaveformGenerator( + duration=self.duration, + sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, + waveform_arguments=dict( + frequency_nodes_linear=fnodes_linear, + frequency_nodes_quadratic=fnodes_quadratic, + reference_frequency=20.0, + waveform_approximant="IMRPhenomPv2", + ), + ) + + @cached_property + def roq(self): + return bilby.gw.likelihood.ROQGravitationalWaveTransient( + interferometers=self.ifos, + waveform_generator=self.roq_wfg, + linear_matrix=self.linear_matrix_file, + quadratic_matrix=self.quadratic_matrix_file, + priors=self.priors, + ) + + @cached_property + def roq_phase(self): + return bilby.gw.likelihood.ROQGravitationalWaveTransient( + interferometers=self.ifos, + waveform_generator=self.roq_wfg, + linear_matrix=self.linear_matrix_file, + quadratic_matrix=self.quadratic_matrix_file, + phase_marginalization=True, + priors=self.priors.copy(), + ) + def test_matches_non_roq(self): self.assertLess( abs( From 311ced4f471fc2dd4c5800779cd0834f2788726e Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 23 Jan 2026 10:33:20 -0500 Subject: [PATCH 061/110] MAINT: revert new conversions --- bilby/gw/conversion.py | 45 ------------------------------------- bilby/gw/likelihood/base.py | 35 +++++++++++------------------ 2 files changed, 13 insertions(+), 67 deletions(-) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 5085ee343..f3f1e5118 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -179,40 +179,6 @@ def transform_precessing_spins(*args): return lalsim_SimInspiralTransformPrecessingNewInitialConditions(*args) -def 
convert_orientation_quaternion(parameters): - xp = array_module(parameters["orientation_w"]) - norm = ( - parameters["orientation_w"]**2 - + parameters["orientation_x"]**2 - + parameters["orientation_y"]**2 - + parameters["orientation_z"]**2 - )**0.5 - parameters["theta_jn"] = 2 * xp.arccos( - parameters["orientation_z"] / norm - ) - parameters["psi"] = xp.arctan2( - parameters["orientation_w"], - parameters["orientation_y"] - + parameters["orientation_x"], - ) / 2 - parameters["delta_phase"] = xp.arctan2( - parameters["orientation_y"], - parameters["orientation_x"], - ) / 2 - - -def convert_cartesian(parameters, label): - spin_norm = ( - parameters[f"{label}_x"]**2 - + parameters[f"{label}_y"]**2 - + parameters[f"{label}_z"]**2 - )**0.5 - xp = array_module(spin_norm) - zenith = xp.arccos(parameters[f"{label}_z"] / spin_norm) - azimuth = xp.arctan2(parameters[f"{label}_y"], parameters[f"{label}_x"]) - return zenith, azimuth - - def convert_to_lal_binary_black_hole_parameters(parameters): """ Convert parameters we have into parameters we need. 
@@ -264,14 +230,6 @@ def convert_to_lal_binary_black_hole_parameters(parameters): converted_parameters = generate_component_masses(converted_parameters, require_add=False) for idx in ['1', '2']: - if f"spin_{idx}_x" in original_keys: - converted_parameters["tilt_1"], converted_parameters["phi_jl"] = ( - convert_cartesian(converted_parameters, "spin_1") - ) - converted_parameters["tilt_2"], converted_parameters["phi_12"] = ( - convert_cartesian(converted_parameters, "spin_2") - ) - converted_parameters["phi_12"] -= converted_parameters["phi_jl"] key = 'chi_{}'.format(idx) if key in original_keys: if "chi_{}_in_plane".format(idx) in original_keys: @@ -302,9 +260,6 @@ def convert_to_lal_binary_black_hole_parameters(parameters): ) converted_parameters[f"cos_tilt_{idx}"] = 1.0 - if "orientation_w" in original_keys: - convert_orientation_quaternion(converted_parameters) - for key in ["phi_jl", "phi_12"]: if key not in converted_parameters: converted_parameters[key] = 0.0 diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index d5459bf48..94df60d33 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -1113,10 +1113,7 @@ def get_sky_frame_parameters(self, parameters=None): ======= dict: dictionary containing ra, dec, and geocent_time """ - from ..conversion import convert_orientation_quaternion, convert_cartesian parameters = _fallback_to_parameters(self, parameters) - if "orientation_w" in parameters: - convert_orientation_quaternion(parameters) time = parameters.get(f'{self.time_reference}_time', None) if time is None and "geocent_time" in parameters: logger.warning( @@ -1124,26 +1121,20 @@ def get_sky_frame_parameters(self, parameters=None): "Falling back to geocent time" ) if not self.reference_frame == "sky": - if "sky_x" in parameters: - zenith, azimuth = convert_cartesian(parameters, "sky") - elif "zenith" in parameters: - zenith = parameters["zenith"] - azimuth = parameters["azimuth"] - elif "ra" in parameters and "dec" 
in parameters: - ra = parameters["ra"] - dec = parameters["dec"] - logger.warning( - "Cannot convert from zenith/azimuth to ra/dec falling " - "back to provided ra/dec" - ) - zenith = None - else: - raise KeyError("No sky location parameters recognised") - if zenith is not None: + try: ra, dec = zenith_azimuth_to_ra_dec( - zenith, azimuth, time, self.reference_frame - ) - else: + parameters['zenith'], parameters['azimuth'], + time, self.reference_frame) + except KeyError: + if "ra" in parameters and "dec" in parameters: + ra = parameters["ra"] + dec = parameters["dec"] + logger.warning( + "Cannot convert from zenith/azimuth to ra/dec falling " + "back to provided ra/dec" + ) + else: + raise ra = parameters["ra"] dec = parameters["dec"] if "geocent" not in self.time_reference and f"{self.time_reference}_time" in parameters: From cb9703abab8fce3c35664b13d95a5247466ba457 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 23 Jan 2026 10:38:29 -0500 Subject: [PATCH 062/110] CI: fix selecting only non-windows os --- .github/workflows/basic-install.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/basic-install.yml b/.github/workflows/basic-install.yml index 976fe9255..9d71d0aab 100644 --- a/.github/workflows/basic-install.yml +++ b/.github/workflows/basic-install.yml @@ -51,7 +51,7 @@ jobs: python -c "import bilby.hyper" python -c "import cli_bilby" python test/import_test.py - - if: ${{ matrix.os != "windows-latest" }} + - if: runner.os != 'Windows' run: | for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do ${script} --help; From ad23f4fe5d9762914c66456ee98b67226649c36a Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 23 Jan 2026 10:44:12 -0500 Subject: [PATCH 063/110] MAINT: make sure compat subpackages are listed in pyproject --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 145d905d1..7f6ce8fbd 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -114,11 +114,13 @@ addopts = [ packages = [ "bilby", "bilby.bilby_mcmc", + "bilby.compat" "bilby.core", "bilby.core.prior", "bilby.core.sampler", "bilby.core.utils", "bilby.gw", + "bilby.gw.compat", "bilby.gw.detector", "bilby.gw.eos", "bilby.gw.likelihood", From 59305686b32bade099849bae24d9b950d9ef87a3 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 23 Jan 2026 10:46:24 -0500 Subject: [PATCH 064/110] TYPO: Fix package list formatting in pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7f6ce8fbd..0d77ffbf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,7 +114,7 @@ addopts = [ packages = [ "bilby", "bilby.bilby_mcmc", - "bilby.compat" + "bilby.compat", "bilby.core", "bilby.core.prior", "bilby.core.sampler", From 230f623babe9f8ac067c6c78b61a8f964b055455 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 23 Jan 2026 11:00:43 -0500 Subject: [PATCH 065/110] BUG: readd erroneously removed line --- bilby/gw/likelihood/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index 94df60d33..310128c46 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -1135,6 +1135,7 @@ def get_sky_frame_parameters(self, parameters=None): ) else: raise + else: ra = parameters["ra"] dec = parameters["dec"] if "geocent" not in self.time_reference and f"{self.time_reference}_time" in parameters: From f65e668d9910336590fde4f67e5d7d81c6b82b59 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 23 Jan 2026 12:16:33 -0500 Subject: [PATCH 066/110] DOC: remove extraneous docstring --- bilby/gw/time.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/bilby/gw/time.py b/bilby/gw/time.py index 51de99e3c..996bad070 100644 --- a/bilby/gw/time.py +++ b/bilby/gw/time.py @@ -172,7 +172,7 @@ def n_leap_seconds(gps_time, leap_seconds): Parameters ---------- - 
gps_time : float + gps_time : float | np.ndarray | int GPS time in seconds. leap_seconds : array_like GPS time of leap seconds. @@ -188,19 +188,6 @@ def n_leap_seconds(gps_time, leap_seconds): @dispatch def n_leap_seconds(gps_time: np.ndarray | float | int): # noqa F811 - """ - Calculate the number of leap seconds that have occurred up to a given GPS time. - - Parameters - ---------- - gps_time : float | np.ndarray | int - GPS time in seconds. - - Returns - ------- - float - Number of leap seconds - """ xp = array_module(gps_time) return n_leap_seconds(gps_time, xp.array(LEAP_SECONDS)) From 2213038aa166d1e37e31c6e8dbfc8d0c555fe964 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 29 Jan 2026 08:09:09 +0000 Subject: [PATCH 067/110] TEST: fix test failures --- bilby/gw/likelihood/multiband.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index 8147ceb83..aeec61387 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -716,11 +716,8 @@ def _setup_time_marginalization_multiband(self): ) / 2 N = self.Nbs[-1] // 2 self._delta_tc = self.durations[0] / N - self._times = ( - np.arange(N) * self._delta_tc + (self._beam_pattern_reference_time - self.interferometers.start_time) - ) % self.interferometers.duration + self.interferometers.start_time - self.time_prior_array = \ - self.priors['geocent_time'].prob(self._times) * self._delta_tc + self._times = self.interferometers.start_time + np.arange(N) * self._delta_tc + self.time_prior_array = self.priors['geocent_time'].prob(self._times) * self._delta_tc # allocate array which is FFTed at each likelihood evaluation self._full_d_h = np.zeros(N, dtype=complex) # idxs to convert full frequency points to banded frequency points, used for filling _full_d_h. 
From a67b4aecb156abca11e9ed67d886c8ae58806b2e Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 31 Jan 2026 11:31:03 +0000 Subject: [PATCH 068/110] TEST: start adding jax tests --- bilby/core/utils/calculus.py | 24 +++++--- bilby/core/utils/samples.py | 7 ++- optional_requirements.txt | 1 + requirements.txt | 3 +- test/conftest.py | 38 ++++++++++++ test/core/likelihood_test.py | 39 +++++++----- test/core/series_test.py | 32 +++++----- test/core/utils_test.py | 114 +++++++++++++++++++++-------------- 8 files changed, 169 insertions(+), 89 deletions(-) diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index 7b9e9b017..bf3714caa 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -6,7 +6,7 @@ from scipy.special import logsumexp from .log import logger -from ...compat.utils import array_module +from ...compat.utils import array_module, xp_wrap def derivatives( @@ -154,7 +154,8 @@ def derivatives( return grads -def logtrapzexp(lnf, dx): +@xp_wrap +def logtrapzexp(lnf, dx, *, xp=np): """ Perform trapezium rule integration for the logarithm of a function on a grid. 
@@ -173,18 +174,23 @@ def logtrapzexp(lnf, dx): lnfdx1 = lnf[:-1] lnfdx2 = lnf[1:] - if isinstance(dx, (int, float)): + + if ( + isinstance(dx, (int, float)) or + (aac.is_array_api_obj(dx) and dx.size == 1) + ): C = np.log(dx / 2.0) - elif isinstance(dx, (list, np.ndarray)): - if len(dx) != len(lnf) - 1: + elif isinstance(dx, (list, xp.ndarray)): + dx = xp.asarray(dx) + if dx.size != len(lnf) - 1: raise ValueError( "Step size array must have length one less than the function length" ) - lndx = np.log(dx) + lndx = xp.log(dx) lnfdx1 = lnfdx1.copy() + lndx lnfdx2 = lnfdx2.copy() + lndx - C = -np.log(2.0) + C = -xp.log(2.0) else: raise TypeError("Step size must be a single value or array-like") @@ -235,7 +241,7 @@ def __init__(self, x, y, z, bbox=[None] * 4, kx=3, ky=3, s=0, fill_value=None): super().__init__(x=x, y=y, z=z, bbox=bbox, kx=kx, ky=ky, s=s) def __call__(self, x, y, dx=0, dy=0, grid=False): - xp = array_module(x) + xp = array_module([x, y]) if aac.is_numpy_namespace(xp): return self._call_scipy(x, y, dx=dx, dy=dy, grid=grid) elif aac.is_jax_namespace(xp): @@ -269,7 +275,7 @@ def _call_jax(self, x, y): jnp.asarray(self.x), jnp.asarray(self.y), jnp.asarray(self.z), - extrap=self.fill_value, + extrap=self.fill_value if self.fill_value is not None else False, method="cubic2", ) diff --git a/bilby/core/utils/samples.py b/bilby/core/utils/samples.py index a075d6dcd..93fdac0ac 100644 --- a/bilby/core/utils/samples.py +++ b/bilby/core/utils/samples.py @@ -1,3 +1,4 @@ +import array_api_extra as xpx import numpy as np from scipy.special import logsumexp @@ -135,7 +136,7 @@ def reflect(u): u: array-like The input array, modified in place. 
""" - idxs_even = np.mod(u, 2) < 1 - u[idxs_even] = np.mod(u[idxs_even], 1) - u[~idxs_even] = 1 - np.mod(u[~idxs_even], 1) + idxs_even = (u % 2) < 1 + u = xpx.at(u)[idxs_even].set(u[idxs_even] % 1) + u = xpx.at(u)[~idxs_even].set(1 - (u[~idxs_even] % 1)) return u diff --git a/optional_requirements.txt b/optional_requirements.txt index c10d7908b..f0f2205f6 100644 --- a/optional_requirements.txt +++ b/optional_requirements.txt @@ -1,5 +1,6 @@ celerite george +parameterized plotly pytest-requires pytest-rerunfailures diff --git a/requirements.txt b/requirements.txt index 3539f45b3..ead66363d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ +array_api_compat +array_api_extra dynesty>=2.0.1 emcee corner @@ -11,4 +13,3 @@ tqdm h5py attrs plum-dispatch -array_api_compat diff --git a/test/conftest.py b/test/conftest.py index d08c38604..6da49894b 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,3 +1,4 @@ +import importlib import pytest @@ -5,6 +6,11 @@ def pytest_addoption(parser): parser.addoption( "--skip-roqs", action="store_true", default=False, help="Skip all tests that require ROQs" ) + parser.addoption( + "--array-backend", + default="numpy", + help="Which array to use for testing", + ) def pytest_configure(config): @@ -17,3 +23,35 @@ def pytest_collection_modifyitems(config, items): for item in items: if "requires_roqs" in item.keywords: item.add_marker(skip_roqs) + if config.getoption("--array-backend") is not None: + array_only = pytest.mark.skip(reason="Only running backend dependent tests") + for item in items: + if "array_backend" not in item.keywords: + item.add_marker(array_only) + + +def _xp(request): + backend = request.config.getoption("--array-backend") + match backend: + case None | "numpy": + import numpy + return numpy + case "jax" | "jax.numpy": + import jax + jax.config.update("jax_enable_x64", True) + return jax.numpy + case _: + try: + importlib.import_module(backend) + except ImportError: + raise ValueError(f"Unknown 
backend for testing: {backend}") + + +@pytest.fixture +def xp(request): + return _xp(request) + + +@pytest.fixture(scope="class") +def xp_class(request): + request.cls.xp = _xp(request) diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index 38c7c70b8..0db6b987c 100644 --- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -2,6 +2,7 @@ from unittest import mock import numpy as np +import pytest import bilby.core.likelihood from bilby.core.likelihood import ( @@ -51,10 +52,12 @@ def test_meta_data(self): self.assertEqual(self.likelihood.meta_data, meta_data) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestAnalytical1DLikelihood(unittest.TestCase): def setUp(self): - self.x = np.arange(start=0, stop=100, step=1) - self.y = np.arange(start=0, stop=100, step=1) + self.x = self.xp.arange(start=0, stop=100, step=1) + self.y = self.xp.arange(start=0, stop=100, step=1) def test_func(x, parameter1, parameter2): return parameter1 * x + parameter2 @@ -80,7 +83,7 @@ def test_init_x(self): self.assertTrue(np.array_equal(self.x, self.analytical_1d_likelihood.x)) def test_set_x_to_array(self): - new_x = np.arange(start=0, stop=50, step=2) + new_x = self.xp.arange(start=0, stop=50, step=2) self.analytical_1d_likelihood.x = new_x self.assertTrue(np.array_equal(new_x, self.analytical_1d_likelihood.x)) @@ -100,7 +103,7 @@ def test_init_y(self): self.assertTrue(np.array_equal(self.y, self.analytical_1d_likelihood.y)) def test_set_y_to_array(self): - new_y = np.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(start=0, stop=50, step=2) self.analytical_1d_likelihood.y = new_y self.assertTrue(np.array_equal(new_y, self.analytical_1d_likelihood.y)) @@ -161,17 +164,20 @@ def test_repr(self): self.assertEqual(expected, repr(self.analytical_1d_likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGaussianLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.sigma = 0.1 
- self.x = np.linspace(0, 1, self.N) - self.y = 2 * self.x + 1 + np.random.normal(0, self.sigma, self.N) + self.x = self.xp.linspace(0, 1, self.N) + self.y = 2 * self.x + 1 + self.xp.array(np.random.normal(0, self.sigma, self.N)) def test_function(x, m, c): return m * x + c self.function = test_function + self.parameters = dict(m=self.xp.array(2.0), c=self.xp.array(0.0)) def tearDown(self): del self.N @@ -182,34 +188,30 @@ def tearDown(self): def test_known_sigma(self): likelihood = GaussianLikelihood(self.x, self.y, self.function, self.sigma) - likelihood.parameters["m"] = 2 - likelihood.parameters["c"] = 0 - likelihood.log_likelihood() + likelihood.log_likelihood(self.parameters) self.assertEqual(likelihood.sigma, self.sigma) def test_known_array_sigma(self): sigma_array = np.ones(self.N) * self.sigma likelihood = GaussianLikelihood(self.x, self.y, self.function, sigma_array) - likelihood.parameters["m"] = 2 - likelihood.parameters["c"] = 0 - likelihood.log_likelihood() + likelihood.log_likelihood(self.parameters) self.assertTrue(type(likelihood.sigma) == type(sigma_array)) # noqa: E721 self.assertTrue(all(likelihood.sigma == sigma_array)) def test_set_sigma_None(self): likelihood = GaussianLikelihood(self.x, self.y, self.function, sigma=None) - likelihood.parameters["m"] = 2 - likelihood.parameters["c"] = 0 self.assertTrue(likelihood.sigma is None) with self.assertRaises(TypeError): - likelihood.log_likelihood() + likelihood.log_likelihood(self.parameters) def test_sigma_float(self): likelihood = GaussianLikelihood(self.x, self.y, self.function, sigma=None) likelihood.parameters["m"] = 2 likelihood.parameters["c"] = 0 likelihood.parameters["sigma"] = 1 - likelihood.log_likelihood() + parameters = self.parameters.copy() + parameters["sigma"] = 1 + likelihood.log_likelihood(parameters) self.assertEqual(likelihood.sigma, 1) def test_sigma_other(self): @@ -224,6 +226,11 @@ def test_repr(self): ) self.assertEqual(expected, repr(likelihood)) + def 
test_return_class(self): + likelihood = GaussianLikelihood(self.x, self.y, self.function, self.sigma) + logl = likelihood.log_likelihood(self.parameters) + self.assertEqual(logl.__array_namespace__(), self.xp) + class TestStudentTLikelihood(unittest.TestCase): def setUp(self): diff --git a/test/core/series_test.py b/test/core/series_test.py index bf1b19c43..7b85c2bc5 100644 --- a/test/core/series_test.py +++ b/test/core/series_test.py @@ -1,15 +1,19 @@ import unittest + import numpy as np +import pytest from bilby.core.utils import create_frequency_series, create_time_series from bilby.core.series import CoupledTimeAndFrequencySeries +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestCoupledTimeAndFrequencySeries(unittest.TestCase): def setUp(self): - self.duration = 2 - self.sampling_frequency = 4096 - self.start_time = -1 + self.duration = self.xp.array(2.0) + self.sampling_frequency = self.xp.array(4096.0) + self.start_time = self.xp.array(-1.0) self.series = CoupledTimeAndFrequencySeries( duration=self.duration, sampling_frequency=self.sampling_frequency, @@ -43,10 +47,10 @@ def test_start_time_from_init(self): self.assertEqual(self.start_time, self.series.start_time) def test_frequency_array_type(self): - self.assertIsInstance(self.series.frequency_array, np.ndarray) + self.assertIsInstance(self.series.frequency_array, self.xp.ndarray) def test_time_array_type(self): - self.assertIsInstance(self.series.time_array, np.ndarray) + self.assertIsInstance(self.series.time_array, self.xp.ndarray) def test_frequency_array_from_init(self): expected = create_frequency_series( @@ -63,8 +67,8 @@ def test_time_array_from_init(self): self.assertTrue(np.array_equal(expected, self.series.time_array)) def test_frequency_array_setter(self): - new_sampling_frequency = 100 - new_duration = 3 + new_sampling_frequency = self.xp.array(100.0) + new_duration = self.xp.array(3.0) new_frequency_array = create_frequency_series( 
sampling_frequency=new_sampling_frequency, duration=new_duration ) @@ -79,9 +83,9 @@ def test_frequency_array_setter(self): self.assertAlmostEqual(self.start_time, self.series.start_time) def test_time_array_setter(self): - new_sampling_frequency = 100 - new_duration = 3 - new_start_time = 4 + new_sampling_frequency = self.xp.array(100.0) + new_duration = self.xp.array(3.0) + new_start_time = self.xp.array(4.0) new_time_array = create_time_series( sampling_frequency=new_sampling_frequency, duration=new_duration, @@ -97,24 +101,24 @@ def test_time_array_setter(self): def test_time_array_without_sampling_frequency(self): self.series.sampling_frequency = None - self.series.duration = 4 + self.series.duration = self.xp.array(4) with self.assertRaises(ValueError): _ = self.series.time_array def test_time_array_without_duration(self): - self.series.sampling_frequency = 4096 + self.series.sampling_frequency = self.xp.array(4096) self.series.duration = None with self.assertRaises(ValueError): _ = self.series.time_array def test_frequency_array_without_sampling_frequency(self): self.series.sampling_frequency = None - self.series.duration = 4 + self.series.duration = self.xp.array(4) with self.assertRaises(ValueError): _ = self.series.frequency_array def test_frequency_array_without_duration(self): - self.series.sampling_frequency = 4096 + self.series.sampling_frequency = self.xp.array(4096) self.series.duration = None with self.assertRaises(ValueError): _ = self.series.frequency_array diff --git a/test/core/utils_test.py b/test/core/utils_test.py index df46d6bb3..fdd4afeef 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -1,6 +1,7 @@ import unittest import os +import array_api_extra as xpx import dill import numpy as np from astropy import constants @@ -49,35 +50,39 @@ def test_gravitational_constant(self): self.assertEqual(bilby.core.utils.gravitational_constant, lal.G_SI) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class 
TestFFT(unittest.TestCase): def setUp(self): - self.sampling_frequency = 10 + self.sampling_frequency = self.xp.array(10) def tearDown(self): del self.sampling_frequency def test_nfft_sine_function(self): - injected_frequency = 2.7324 - duration = 100 - times = utils.create_time_series(self.sampling_frequency, duration) + xp = self.xp + injected_frequency = xp.array(2.7324) + duration = xp.array(100) + times = utils.create_time_series(xp.array(self.sampling_frequency), duration) - time_domain_strain = np.sin(2 * np.pi * times * injected_frequency + 0.4) + time_domain_strain = xp.sin(2 * np.pi * times * injected_frequency + 0.4) frequency_domain_strain, frequencies = bilby.core.utils.nfft( time_domain_strain, self.sampling_frequency ) - frequency_at_peak = frequencies[np.argmax(np.abs(frequency_domain_strain))] + frequency_at_peak = frequencies[xp.argmax(abs(frequency_domain_strain))] self.assertAlmostEqual(injected_frequency, frequency_at_peak, places=1) def test_nfft_infft(self): - time_domain_strain = np.random.normal(0, 1, 10) + xp = self.xp + time_domain_strain = xp.array(np.random.normal(0, 1, 10)) frequency_domain_strain, _ = bilby.core.utils.nfft( time_domain_strain, self.sampling_frequency ) new_time_domain_strain = bilby.core.utils.infft( frequency_domain_strain, self.sampling_frequency ) - self.assertTrue(np.allclose(time_domain_strain, new_time_domain_strain)) + self.assertTrue(xp.allclose(time_domain_strain, new_time_domain_strain)) class TestInferParameters(unittest.TestCase): @@ -119,11 +124,13 @@ def test_self_handling_method_as_function(self): self.assertListEqual(expected, actual) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestTimeAndFrequencyArrays(unittest.TestCase): def setUp(self): - self.start_time = 1.3 - self.sampling_frequency = 5 - self.duration = 1.6 + self.start_time = self.xp.array(1.3) + self.sampling_frequency = self.xp.array(5) + self.duration = self.xp.array(1.6) self.frequency_array = 
utils.create_frequency_series( sampling_frequency=self.sampling_frequency, duration=self.duration ) @@ -141,12 +148,13 @@ def tearDown(self): del self.time_array def test_create_time_array(self): - expected_time_array = np.array([1.3, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7]) + expected_time_array = self.xp.array([1.3, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7]) time_array = utils.create_time_series( sampling_frequency=self.sampling_frequency, duration=self.duration, starting_time=self.start_time, ) + self.assertEqual(time_array.__array_namespace__(), self.xp) self.assertTrue(np.allclose(expected_time_array, time_array)) def test_create_frequency_array(self): @@ -164,7 +172,7 @@ def test_get_sampling_frequency_from_time_array(self): self.assertEqual(self.sampling_frequency, new_sampling_freq) def test_get_sampling_frequency_from_time_array_unequally_sampled(self): - self.time_array[-1] += 0.0001 + self.time_array = xpx.at(self.time_array, -1).set(self.time_array[-1] + 0.0001) with self.assertRaises(ValueError): _, _ = utils.get_sampling_frequency_and_duration_from_time_array( self.time_array @@ -190,7 +198,9 @@ def test_get_sampling_frequency_from_frequency_array(self): self.assertEqual(self.sampling_frequency, new_sampling_freq) def test_get_sampling_frequency_from_frequency_array_unequally_sampled(self): - self.frequency_array[-1] += 0.0001 + self.frequency_array = xpx.at( + self.frequency_array, -1 + ).set(self.frequency_array[-1] + 0.0001) with self.assertRaises(ValueError): _, _ = utils.get_sampling_frequency_and_duration_from_frequency_array( self.frequency_array @@ -233,34 +243,38 @@ def test_consistency_frequency_array_to_frequency_array(self): def test_illegal_sampling_frequency_and_duration(self): with self.assertRaises(utils.IllegalDurationAndSamplingFrequencyException): _ = utils.create_time_series( - sampling_frequency=7.7, duration=1.3, starting_time=0 + sampling_frequency=self.xp.array(7.7), + duration=self.xp.array(1.3), + starting_time=self.xp.array(0), ) 
+@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestReflect(unittest.TestCase): def test_in_range(self): - xprime = np.array([0.1, 0.5, 0.9]) - x = np.array([0.1, 0.5, 0.9]) + xprime = self.xp.array([0.1, 0.5, 0.9]) + x = self.xp.array([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_one_to_two(self): - xprime = np.array([1.1, 1.5, 1.9]) - x = np.array([0.9, 0.5, 0.1]) + xprime = self.xp.array([1.1, 1.5, 1.9]) + x = self.xp.array([0.9, 0.5, 0.1]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_two_to_three(self): - xprime = np.array([2.1, 2.5, 2.9]) - x = np.array([0.1, 0.5, 0.9]) + xprime = self.xp.array([2.1, 2.5, 2.9]) + x = self.xp.array([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_minus_one_to_zero(self): - xprime = np.array([-0.9, -0.5, -0.1]) - x = np.array([0.9, 0.5, 0.1]) + xprime = self.xp.array([-0.9, -0.5, -0.1]) + x = self.xp.array([0.9, 0.5, 0.1]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_minus_two_to_minus_one(self): - xprime = np.array([-1.9, -1.5, -1.1]) - x = np.array([0.1, 0.5, 0.9]) + xprime = self.xp.array([-1.9, -1.5, -1.1]) + x = self.xp.array([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) @@ -325,6 +339,8 @@ def plot(): self.assertTrue(os.path.isfile(self.filename)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestUnsortedInterp2d(unittest.TestCase): def setUp(self): self.xx = np.linspace(0, 1, 10) @@ -343,36 +359,42 @@ def test_returns_none_for_floats_outside_range(self): self.assertIsNone(self.interpolant(-0.5, 0.5)) def test_returns_float_for_float_and_array(self): - self.assertIsInstance(self.interpolant(0.5, np.random.random(10)), np.ndarray) - self.assertIsInstance(self.interpolant(np.random.random(10), 0.5), np.ndarray) - 
self.assertIsInstance( - self.interpolant(np.random.random(10), np.random.random(10)), np.ndarray + input_array = self.xp.array(np.random.random(10)) + self.assertEqual(self.interpolant(input_array, 0.5).__array_namespace__(), self.xp) + self.assertEqual( + self.interpolant(input_array, input_array).__array_namespace__(), self.xp ) + self.assertEqual(self.interpolant(0.5, input_array).__array_namespace__(), self.xp) def test_raises_for_mismatched_arrays(self): with self.assertRaises(ValueError): - self.interpolant(np.random.random(10), np.random.random(20)) + self.interpolant( + self.xp.array(np.random.random(10)), + self.xp.array(np.random.random(20)), + ) def test_returns_fill_in_correct_place(self): - x_data = np.random.random(10) - y_data = np.random.random(10) - x_data[3] = -1 - self.assertTrue(np.isnan(self.interpolant(x_data, y_data)[3])) + x_data = self.xp.array(np.random.random(10)) + y_data = self.xp.array(np.random.random(10)) + x_data = xpx.at(x_data, 3).set(-1) + self.assertTrue(self.xp.isnan(self.interpolant(x_data, y_data)[3])) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestTrapeziumRuleIntegration(unittest.TestCase): def setUp(self): - self.x = np.linspace(0, 1, 100) - self.dxs = np.diff(self.x) + self.x = self.xp.linspace(0, 1, 100) + self.dxs = self.xp.diff(self.x) self.dx = self.dxs[0] with np.errstate(divide="ignore"): - self.lnfunc1 = np.log(self.x) + self.lnfunc1 = self.xp.log(self.x) self.func1int = (self.x[-1] ** 2 - self.x[0] ** 2) / 2 with np.errstate(divide="ignore"): - self.lnfunc2 = np.log(self.x ** 2) + self.lnfunc2 = self.xp.log(self.x ** 2) self.func2int = (self.x[-1] ** 3 - self.x[0] ** 3) / 3 - self.irregularx = np.array( + self.irregularx = self.xp.array( [ self.x[0], self.x[12], @@ -390,9 +412,9 @@ def setUp(self): ] ) with np.errstate(divide="ignore"): - self.lnfunc1irregular = np.log(self.irregularx) - self.lnfunc2irregular = np.log(self.irregularx ** 2) - self.irregulardxs = 
np.diff(self.irregularx) + self.lnfunc1irregular = self.xp.log(self.irregularx) + self.lnfunc2irregular = self.xp.log(self.irregularx ** 2) + self.irregulardxs = self.xp.diff(self.irregularx) def test_incorrect_step_type(self): with self.assertRaises(TypeError): @@ -407,19 +429,19 @@ def test_integral_func1(self): res2 = utils.logtrapzexp(self.lnfunc1, self.dxs) self.assertTrue(np.abs(res1 - res2) < 1e-12) - self.assertTrue(np.abs((np.exp(res1) - self.func1int) / self.func1int) < 1e-12) + self.assertTrue(np.abs((self.xp.exp(res1) - self.func1int) / self.func1int) < 1e-12) def test_integral_func2(self): res = utils.logtrapzexp(self.lnfunc2, self.dxs) - self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-4) + self.assertTrue(np.abs((self.xp.exp(res) - self.func2int) / self.func2int) < 1e-4) def test_integral_func1_irregular_steps(self): res = utils.logtrapzexp(self.lnfunc1irregular, self.irregulardxs) - self.assertTrue(np.abs((np.exp(res) - self.func1int) / self.func1int) < 1e-12) + self.assertTrue(np.abs((self.xp.exp(res) - self.func1int) / self.func1int) < 1e-12) def test_integral_func2_irregular_steps(self): res = utils.logtrapzexp(self.lnfunc2irregular, self.irregulardxs) - self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-2) + self.assertTrue(np.abs((self.xp.exp(res) - self.func2int) / self.func2int) < 1e-2) class TestSavingNumpyRandomGenerator(unittest.TestCase): From 080df9d206deca9145b61cecfbf9ce1a2e9c089e Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 31 Jan 2026 11:33:11 +0000 Subject: [PATCH 069/110] CI: add jax tests to CI --- .github/workflows/unit-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index efaa94c29..43316faac 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -58,6 +58,9 @@ jobs: - name: Run unit tests run: | pytest --cov=bilby --durations 10 + - name: Run jax-backend 
unit tests + run: | + pytest --array-backend jax --durations 10 - name: Run sampler tests run: | pytest test/integration/sampler_run_test.py --durations 10 -v From 164bc7015e2bff80cc6d5935a3aac3b97b66d251 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 31 Jan 2026 12:34:40 +0000 Subject: [PATCH 070/110] MAINT: add jax extras option --- .github/workflows/unit-tests.yml | 1 + pyproject.toml | 2 ++ 2 files changed, 3 insertions(+) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 9aa94e2ff..395248530 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -62,6 +62,7 @@ jobs: python -m pytest --cov=bilby --cov-branch --durations 10 -ra --color yes --cov-report=xml --junitxml=pytest.xml - name: Run jax-backend unit tests run: | + python -m pip install .[jax] pytest --array-backend jax --durations 10 - name: Run sampler tests run: | diff --git a/pyproject.toml b/pyproject.toml index 0d77ffbf1..b2ccdb444 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,11 +135,13 @@ dependencies = {file = ["requirements.txt"]} [tool.setuptools.dynamic.optional-dependencies] all = {file = [ "gw_requirements.txt", + "jax_requirements.txt", "mcmc_requirements.txt", "sampler_requirements.txt", "optional_requirements.txt" ]} gw = {file = ["gw_requirements.txt"]} +jax = {file = ["jax_requirements.txt"]} mcmc = {file = ["mcmc_requirements.txt"]} [tool.setuptools.package-data] From 2205fc2b68754ce93d2c16e32753e56e2f50b278 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 31 Jan 2026 12:34:59 +0000 Subject: [PATCH 071/110] Some more jax testing updates --- bilby/core/likelihood.py | 30 ++++---- bilby/core/utils/calculus.py | 2 +- test/core/likelihood_test.py | 143 ++++++++++++++++++----------------- 3 files changed, 90 insertions(+), 85 deletions(-) diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index eb350253f..6dc025510 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ 
-350,13 +350,13 @@ def log_likelihood(self, parameters=None): raise ValueError( "Poisson rate function returns wrong value type! " "Is {} when it should be numpy.ndarray".format(type(rate))) - elif any(rate < 0.): + xp = rate.__array_namespace__() + if xp.any(rate < 0.): raise ValueError(("Poisson rate function returns a negative", " value!")) - elif any(rate == 0.): + elif xp.any(rate == 0.): return -np.inf else: - xp = array_module(rate) return xp.sum(-rate + self.y * xp.log(rate) - gammaln(self.y + 1)) def __repr__(self): @@ -369,10 +369,11 @@ def y(self): @y.setter def y(self, y): - if not isinstance(y, np.ndarray): - y = np.array([y]) + if not is_array_api_obj(y): + y = np.atleast_1d(y) + xp = y.__array_namespace__() # check array is a non-negative integer array - if y.dtype.kind not in 'ui' or np.any(y < 0): + if y.dtype.kind not in 'ui' or xp.any(y < 0): raise ValueError("Data must be non-negative integers") self.__y = y @@ -398,7 +399,7 @@ def __init__(self, x, y, func, **kwargs): def log_likelihood(self, parameters=None): mu = self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) - xp = array_module(mu) + xp = mu.__array_namespace__() if xp.any(mu < 0.): return -np.inf return -xp.sum(xp.log(mu) + (self.y / mu)) @@ -413,9 +414,10 @@ def y(self): @y.setter def y(self, y): - if not isinstance(y, np.ndarray): - y = np.array([y]) - if any(y < 0): + if not is_array_api_obj(y): + y = np.atleast_1d(y) + xp = y.__array_namespace__() + if xp.any(y < 0): raise ValueError("Data must be non-negative") self._y = y @@ -573,8 +575,7 @@ def __init__(self, mean, cov): f"Multivariate normal likelihood not implemented for {xp.__name__} backend" ) - parameters = {"x{0}".format(i): 0 for i in range(self.dim)} - super(AnalyticalMultidimensionalCovariantGaussian, self).__init__(parameters=parameters) + super(AnalyticalMultidimensionalCovariantGaussian, self).__init__() @property def dim(self): @@ -614,8 +615,7 @@ def __init__(self, mean_1, mean_2, cov): 
raise NotImplementedError( f"Multivariate normal likelihood not implemented for {xp.__name__} backend" ) - parameters = {"x{0}".format(i): 0 for i in range(self.dim)} - super(AnalyticalMultidimensionalBimodalCovariantGaussian, self).__init__(parameters=parameters) + super(AnalyticalMultidimensionalBimodalCovariantGaussian, self).__init__() @property def dim(self): @@ -624,7 +624,7 @@ def dim(self): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) xp = array_module(self.cov) - x = xp.array([self.parameters["x{0}".format(i)] for i in range(self.dim)]) + x = xp.array([parameters["x{0}".format(i)] for i in range(self.dim)]) return -xp.log(2) + xp.logaddexp(self.logpdf_1(x), self.logpdf_2(x)) diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index bf3714caa..13e1b2586 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -174,7 +174,7 @@ def logtrapzexp(lnf, dx, *, xp=np): lnfdx1 = lnf[:-1] lnfdx2 = lnf[1:] - + if ( isinstance(dx, (int, float)) or (aac.is_array_api_obj(dx) and dx.size == 1) diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index 0db6b987c..80d1324c8 100644 --- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -3,6 +3,7 @@ import numpy as np import pytest +import array_api_extra as xpx import bilby.core.likelihood from bilby.core.likelihood import ( @@ -206,12 +207,7 @@ def test_set_sigma_None(self): def test_sigma_float(self): likelihood = GaussianLikelihood(self.x, self.y, self.function, sigma=None) - likelihood.parameters["m"] = 2 - likelihood.parameters["c"] = 0 - likelihood.parameters["sigma"] = 1 - parameters = self.parameters.copy() - parameters["sigma"] = 1 - likelihood.log_likelihood(parameters) + likelihood.sigma = 1.0 self.assertEqual(likelihood.sigma, 1) def test_sigma_other(self): @@ -232,18 +228,21 @@ def test_return_class(self): self.assertEqual(logl.__array_namespace__(), self.xp) 
+@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestStudentTLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.nu = self.N - 2 self.sigma = 1 - self.x = np.linspace(0, 1, self.N) - self.y = 2 * self.x + 1 + np.random.normal(0, self.sigma, self.N) + self.x = self.xp.linspace(0, 1, self.N) + self.y = 2 * self.x + 1 + self.xp.array(np.random.normal(0, self.sigma, self.N)) def test_function(x, m, c): return m * x + c self.function = test_function + self.parameters = dict(m=self.xp.array(2.0), c=self.xp.array(0.0)) def tearDown(self): del self.N @@ -256,44 +255,31 @@ def test_known_sigma(self): likelihood = StudentTLikelihood( self.x, self.y, self.function, self.nu, self.sigma ) - likelihood.parameters["m"] = 2 - likelihood.parameters["c"] = 0 - likelihood.log_likelihood() + likelihood.log_likelihood(self.parameters) self.assertEqual(likelihood.sigma, self.sigma) def test_set_nu_none(self): likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=None) - likelihood.parameters["m"] = 2 - likelihood.parameters["c"] = 0 self.assertTrue(likelihood.nu is None) def test_log_likelihood_nu_none(self): likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=None) - likelihood.parameters["m"] = 2 - likelihood.parameters["c"] = 0 - with self.assertRaises((ValueError, TypeError)): - # ValueError in Python2, TypeError in Python3 - likelihood.log_likelihood() + with self.assertRaises(TypeError): + likelihood.log_likelihood(self.parameters) def test_log_likelihood_nu_zero(self): likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=0) - likelihood.parameters["m"] = 2 - likelihood.parameters["c"] = 0 with self.assertRaises(ValueError): - likelihood.log_likelihood() + likelihood.log_likelihood(self.parameters) def test_log_likelihood_nu_negative(self): likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=-1) - likelihood.parameters["m"] = 2 - likelihood.parameters["c"] = 0 with 
self.assertRaises(ValueError): - likelihood.log_likelihood() + likelihood.log_likelihood(self.parameters) def test_setting_nu_positive_does_not_change_class_attribute(self): likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=None) - likelihood.parameters["m"] = 2 - likelihood.parameters["c"] = 0 - likelihood.parameters["nu"] = 98 + likelihood.nu = 98 self.assertEqual(likelihood.nu, 98) def test_lam(self): @@ -313,25 +299,28 @@ def test_repr(self): self.assertEqual(expected, repr(likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestPoissonLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.mu = 5 - self.x = np.linspace(0, 1, self.N) - self.y = np.random.poisson(self.mu, self.N) - self.yfloat = np.copy(self.y) * 1.0 - self.yneg = np.copy(self.y) - self.yneg[0] = -1 + self.x = self.xp.linspace(0, 1, self.N) + self.y = self.xp.array(np.random.poisson(self.mu, self.N)) + self.yfloat = self.y.copy() * 1.0 + self.yneg = self.y.copy() + self.yneg = xpx.at(self.yneg, 0).set(-1) def test_function(x, c): return c def test_function_array(x, c): - return np.ones(len(x)) * c + return self.xp.ones(len(x)) * c self.function = test_function self.function_array = test_function_array self.poisson_likelihood = PoissonLikelihood(self.x, self.y, self.function) + self.bad_parameters = dict(c=self.xp.array(-2.0)) def tearDown(self): del self.N @@ -353,23 +342,21 @@ def test_init__y_negative(self): PoissonLikelihood(self.x, self.yneg, self.function) def test_neg_rate(self): - self.poisson_likelihood.parameters["c"] = -2 with self.assertRaises(ValueError): - self.poisson_likelihood.log_likelihood() + self.poisson_likelihood.log_likelihood(self.bad_parameters) def test_neg_rate_array(self): likelihood = PoissonLikelihood(self.x, self.y, self.function_array) - likelihood.parameters["c"] = -2 with self.assertRaises(ValueError): - likelihood.log_likelihood() + likelihood.log_likelihood(self.bad_parameters) def test_init_y(self): - 
self.assertTrue(np.array_equal(self.y, self.poisson_likelihood.y)) + self.assertTrue(self.xp.array_equal(self.y, self.poisson_likelihood.y)) def test_set_y_to_array(self): - new_y = np.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(start=0, stop=50, step=2) self.poisson_likelihood.y = new_y - self.assertTrue(np.array_equal(new_y, self.poisson_likelihood.y)) + self.assertTrue(self.xp.array_equal(new_y, self.poisson_likelihood.y)) def test_set_y_to_positive_int(self): new_y = 5 @@ -394,25 +381,25 @@ def test_log_likelihood_wrong_func_return_type(self): def test_log_likelihood_negative_func_return_element(self): poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([3, 6, -2]) + x=self.x, y=self.y, func=lambda x: self.xp.array([3, 6, -2]) ) with self.assertRaises(ValueError): poisson_likelihood.log_likelihood() def test_log_likelihood_zero_func_return_element(self): poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([3, 6, 0]) + x=self.x, y=self.y, func=lambda x: self.xp.array([3, 6, 0]) ) self.assertEqual(-np.inf, poisson_likelihood.log_likelihood()) def test_log_likelihood_dummy(self): """ Merely tests if it goes into the right if else bracket """ poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.linspace(1, 100, self.N) + x=self.x, y=self.y, func=lambda x: self.xp.linspace(1, 100, self.N) ) - with mock.patch("array_api_compat.numpy.sum") as m: + with mock.patch(f"{self.xp.__name__}.sum") as m: m.return_value = 1 - self.assertEqual(1, poisson_likelihood.log_likelihood()) + self.assertEqual(1, poisson_likelihood.log_likelihood(dict(c=5))) def test_repr(self): likelihood = PoissonLikelihood(self.x, self.y, self.function) @@ -422,26 +409,29 @@ def test_repr(self): self.assertEqual(expected, repr(likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestExponentialLikelihood(unittest.TestCase): def setUp(self): self.N = 100 self.mu = 
5 - self.x = np.linspace(0, 1, self.N) - self.y = np.random.exponential(self.mu, self.N) - self.yneg = np.copy(self.y) - self.yneg[0] = -1.0 + self.x = self.xp.linspace(0, 1, self.N) + self.y = self.xp.array(np.random.exponential(self.mu, self.N)) + self.yneg = self.y.copy() + self.yneg = xpx.at(self.yneg, 0).set(-1.0) def test_function(x, c): return c def test_function_array(x, c): - return c * np.ones(len(x)) + return c * self.xp.ones(len(x)) self.function = test_function self.function_array = test_function_array self.exponential_likelihood = ExponentialLikelihood( x=self.x, y=self.y, func=self.function ) + self.bad_parameters = dict(c=self.xp.array(-1.0)) def tearDown(self): del self.N @@ -458,19 +448,17 @@ def test_negative_data(self): def test_negative_function(self): likelihood = ExponentialLikelihood(self.x, self.y, self.function) - likelihood.parameters["c"] = -1 - self.assertEqual(likelihood.log_likelihood(), -np.inf) + self.assertEqual(likelihood.log_likelihood(self.bad_parameters), -np.inf) def test_negative_array_function(self): likelihood = ExponentialLikelihood(self.x, self.y, self.function_array) - likelihood.parameters["c"] = -1 - self.assertEqual(likelihood.log_likelihood(), -np.inf) + self.assertEqual(likelihood.log_likelihood(self.bad_parameters), -np.inf) def test_init_y(self): self.assertTrue(np.array_equal(self.y, self.exponential_likelihood.y)) def test_set_y_to_array(self): - new_y = np.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(start=0, stop=50, step=2) self.exponential_likelihood.y = new_y self.assertTrue(np.array_equal(new_y, self.exponential_likelihood.y)) @@ -495,14 +483,14 @@ def test_set_y_to_negative_float(self): def test_set_y_to_nd_array_with_negative_element(self): with self.assertRaises(ValueError): - self.exponential_likelihood.y = np.array([4.3, -1.2, 4]) + self.exponential_likelihood.y = self.xp.array([4.3, -1.2, 4]) def test_log_likelihood_default(self): """ Merely tests that it ends up at the right place in 
the code """ exponential_likelihood = ExponentialLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([4.2]) + x=self.x, y=self.y, func=lambda x: self.xp.array([4.2]) ) - with mock.patch("array_api_compat.numpy.sum") as m: + with mock.patch(f"{self.xp.__name__}.sum") as m: m.return_value = 3 self.assertEqual(-3, exponential_likelihood.log_likelihood()) @@ -513,15 +501,21 @@ def test_repr(self): self.assertEqual(expected, repr(self.exponential_likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestAnalyticalMultidimensionalCovariantGaussian(unittest.TestCase): def setUp(self): self.cov = [[1, 0, 0], [0, 4, 0], [0, 0, 9]] self.sigma = [1, 2, 3] self.mean = [10, 11, 12] + if self.xp != np: + self.cov = self.xp.array(self.cov) + self.sigma = self.xp.array(self.sigma) + self.mean = self.xp.array(self.mean) self.likelihood = AnalyticalMultidimensionalCovariantGaussian( mean=self.mean, cov=self.cov ) - self.likelihood.parameters.update({f"x{ii}": 0 for ii in range(len(self.sigma))}) + self.parameters = {f"x{ii}": 0 for ii in range(len(self.sigma))} def tearDown(self): del self.cov @@ -538,27 +532,38 @@ def test_mean(self): def test_sigma(self): self.assertTrue(np.array_equal(self.sigma, self.likelihood.sigma)) - def test_parameters(self): - self.assertDictEqual(dict(x0=0, x1=0, x2=0), self.likelihood.parameters) - def test_dim(self): self.assertEqual(3, self.likelihood.dim) def test_log_likelihood(self): - likelihood = AnalyticalMultidimensionalCovariantGaussian(mean=[0], cov=[1]) - self.assertEqual(-np.log(2 * np.pi) / 2, likelihood.log_likelihood(dict(x0=0))) + likelihood = AnalyticalMultidimensionalCovariantGaussian( + mean=self.xp.array([0]), cov=self.xp.array([1]) + ) + logl = likelihood.log_likelihood(dict(x0=self.xp.array(0.0))) + self.assertEqual( + -np.log(2 * np.pi) / 2, + likelihood.log_likelihood(dict(x0=self.xp.array(0.0))), + ) + self.assertEqual(logl.__array_namespace__(), self.xp) +@pytest.mark.array_backend 
+@pytest.mark.usefixtures("xp_class") class TestAnalyticalMultidimensionalBimodalCovariantGaussian(unittest.TestCase): def setUp(self): self.cov = [[1, 0, 0], [0, 4, 0], [0, 0, 9]] self.sigma = [1, 2, 3] self.mean_1 = [10, 11, 12] self.mean_2 = [20, 21, 22] + if self.xp != np: + self.cov = self.xp.array(self.cov) + self.sigma = self.xp.array(self.sigma) + self.mean_1 = self.xp.array(self.mean_1) + self.mean_2 = self.xp.array(self.mean_2) self.likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian( mean_1=self.mean_1, mean_2=self.mean_2, cov=self.cov ) - self.likelihood.parameters.update({f"x{ii}": 0 for ii in range(len(self.sigma))}) + self.parameters = {f"x{ii}": 0 for ii in range(len(self.sigma))} def tearDown(self): del self.cov @@ -579,9 +584,6 @@ def test_mean_2(self): def test_sigma(self): self.assertTrue(np.array_equal(self.sigma, self.likelihood.sigma)) - def test_parameters(self): - self.assertDictEqual(dict(x0=0, x1=0, x2=0), self.likelihood.parameters) - def test_dim(self): self.assertEqual(3, self.likelihood.dim) @@ -589,7 +591,10 @@ def test_log_likelihood(self): likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian( mean_1=[0], mean_2=[0], cov=[1] ) - self.assertEqual(-np.log(2 * np.pi) / 2, likelihood.log_likelihood(dict(x0=0))) + self.assertEqual( + -np.log(2 * np.pi) / 2, + likelihood.log_likelihood(dict(x0=self.xp.array(0.0))), + ) class TestJointLikelihood(unittest.TestCase): From aec63afcff33651773ec4af3d6e1aa30101d64b0 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 31 Jan 2026 12:42:02 +0000 Subject: [PATCH 072/110] MAINT: actually add jax requirements --- jax_requirements.txt | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 jax_requirements.txt diff --git a/jax_requirements.txt b/jax_requirements.txt new file mode 100644 index 000000000..b325586a3 --- /dev/null +++ b/jax_requirements.txt @@ -0,0 +1,2 @@ +interpax +jax \ No newline at end of file From cc79c5411af3a288a001ab116bb234e7ea0ff833 Mon Sep 17 
00:00:00 2001 From: Colm Talbot Date: Sat, 31 Jan 2026 12:55:10 +0000 Subject: [PATCH 073/110] CI: don't trivially skip all tests... --- test/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/conftest.py b/test/conftest.py index 6da49894b..6efaa82e2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -8,7 +8,7 @@ def pytest_addoption(parser): ) parser.addoption( "--array-backend", - default="numpy", + default=None, help="Which array to use for testing", ) From 9d4e01a3bc213972a0fc2aaac84155a1dd5e659b Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sat, 31 Jan 2026 13:40:14 +0000 Subject: [PATCH 074/110] Initial pass at making grid work with jax --- bilby/core/grid.py | 80 +++++++++++++++++++++++++----------- bilby/core/utils/calculus.py | 2 +- test/core/grid_test.py | 44 +++++++++++++------- 3 files changed, 86 insertions(+), 40 deletions(-) diff --git a/bilby/core/grid.py b/bilby/core/grid.py index 0d103d4cc..2c09f75bd 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -10,6 +10,7 @@ BilbyJsonEncoder, load_json, move_old_file ) from .result import FileMovedError +from ..compat.utils import array_module def grid_file_name(outdir, label, gzip=False): @@ -36,8 +37,11 @@ def grid_file_name(outdir, label, gzip=False): class Grid(object): - def __init__(self, likelihood=None, priors=None, grid_size=101, - save=False, label='no_label', outdir='.', gzip=False): + def __init__( + self, likelihood=None, priors=None, grid_size=101, + save=False, label='no_label', outdir='.', gzip=False, + xp=None, + ): """ Parameters @@ -58,8 +62,16 @@ def __init__(self, likelihood=None, priors=None, grid_size=101, The output directory to which the grid will be saved gzip: bool Set whether to gzip the output grid file + xp: array module | None + The array module to use for calculations (e.g., :code:`numpy`, + :code:`cupy`). If :code:`None`, defaults to :code:`numpy`. 
+ """ + if xp is None: + xp = np + logger.debug("No array module given for grid, defaulting to numpy.") + if priors is None: priors = dict() self.likelihood = likelihood @@ -68,13 +80,15 @@ def __init__(self, likelihood=None, priors=None, grid_size=101, self.parameter_names = list(self.priors.keys()) self.sample_points = dict() - self._get_sample_points(grid_size) + self._get_sample_points(grid_size, xp=xp) # evaluate the prior on the grid points if self.n_dims > 0: self._ln_prior = self.priors.ln_prob( {key: self.mesh_grid[i].flatten() for i, key in enumerate(self.parameter_names)}, axis=0).reshape( self.mesh_grid[0].shape) + else: + self._ln_prior = xp.array(0.0) self._ln_likelihood = None # evaluate the likelihood on the grid points @@ -97,12 +111,14 @@ def ln_prior(self): @property def prior(self): - return np.exp(self.ln_prior) + lnp = self.ln_prior + xp = array_module(lnp) + return xp.exp(lnp) @property def ln_likelihood(self): if self._ln_likelihood is None: - self._evaluate() + self._evaluate(xp=array_module(self._ln_prior)) return self._ln_likelihood @property @@ -116,7 +132,8 @@ def marginalize(self, log_array, parameters=None, not_parameters=None): Parameters ========== log_array: array_like - A :class:`numpy.ndarray` of log likelihood/posterior values. + A :code:`Python` array-api compatible array of log + likelihood/posterior values. parameters: list, str A list, or single string, of parameters to marginalize over. If None then all parameters will be marginalized over. @@ -166,7 +183,8 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): Parameters ========== log_array: array_like - A :class:`numpy.ndarray` of log likelihood/posterior values. + A :code:`Python` array-api compatible array of log + likelihood/posterior values. name: str The name of the parameter to marginalize over. 
non_marg_names: list @@ -189,17 +207,19 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): non_marg_names.remove(name) places = self.sample_points[name] + xp = log_array.__array_namespace__() + print(xp) if len(places) > 1: - dx = np.diff(places) - out = np.apply_along_axis( + dx = xp.diff(places) + out = xp.apply_along_axis( logtrapzexp, axis, log_array, dx ) else: # no marginalisation required, just remove the singleton dimension z = log_array.shape - q = np.arange(0, len(z)).astype(int) != axis - out = np.reshape(log_array, tuple((np.array(list(z)))[q])) + q = xp.arange(0, len(z)).astype(int) != axis + out = xp.reshape(log_array, tuple((xp.array(list(z)))[q])) return out @@ -277,8 +297,9 @@ def marginalize_likelihood(self, parameters=None, not_parameters=None): """ ln_like = self.marginalize(self.ln_likelihood, parameters=parameters, not_parameters=not_parameters) + xp = ln_like.__array_namespace__() # NOTE: the output will not be properly normalised - return np.exp(ln_like - np.max(ln_like)) + return xp.exp(ln_like - xp.max(ln_like)) def marginalize_posterior(self, parameters=None, not_parameters=None): """ @@ -301,20 +322,31 @@ def marginalize_posterior(self, parameters=None, not_parameters=None): ln_post = self.marginalize(self.ln_posterior, parameters=parameters, not_parameters=not_parameters) # NOTE: the output will not be properly normalised - return np.exp(ln_post - np.max(ln_post)) + xp = ln_post.__array_namespace__() + return xp.exp(ln_post - xp.max(ln_post)) def _evaluate(self): - self._ln_likelihood = np.empty(self.mesh_grid[0].shape) - self._evaluate_recursion(0, parameters=dict()) + xp = self.mesh_grid[0].__array_namespace__() + if xp.__name__ == "jax.numpy": + from jax import vmap + self._ln_likelihood = vmap(self.likelihood.log_likelihood)( + {key: self.mesh_grid[i].flatten() for i, key in enumerate(self.parameter_names)} + ).reshape(self.mesh_grid[0].shape) + print(type(self._ln_likelihood)) + + else: + self._ln_likelihood = 
xp.empty(self.mesh_grid[0].shape) + self._evaluate_recursion(0, parameters=dict()) self.ln_noise_evidence = self.likelihood.noise_log_likelihood() def _evaluate_recursion(self, dimension, parameters): if dimension == self.n_dims: - current_point = tuple([[int(np.where( + xp = self.mesh_grid[0].__array_namespace__() + current_point = tuple([[xp.where( parameters[name] == - self.sample_points[name])[0])] for name in self.parameter_names]) - self._ln_likelihood[current_point] = _safe_likelihood_call( - self.likelihood, parameters + self.sample_points[name])[0].item()] for name in self.parameter_names]) + self._ln_likelihood[current_point] = ( +_safe_likelihood_call(self.likelihood, parameters) ) else: name = self.parameter_names[dimension] @@ -322,29 +354,29 @@ def _evaluate_recursion(self, dimension, parameters): parameters[name] = self.sample_points[name][ii] self._evaluate_recursion(dimension + 1, parameters) - def _get_sample_points(self, grid_size): + def _get_sample_points(self, grid_size, *, xp=np): for ii, key in enumerate(self.parameter_names): if isinstance(self.priors[key], Prior): if isinstance(grid_size, int): self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size)) + xp.linspace(0, 1, grid_size)) elif isinstance(grid_size, list): if isinstance(grid_size[ii], int): self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size[ii])) + xp.linspace(0, 1, grid_size[ii])) else: self.sample_points[key] = grid_size[ii] elif isinstance(grid_size, dict): if isinstance(grid_size[key], int): self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size[key])) + xp.linspace(0, 1, grid_size[key])) else: self.sample_points[key] = grid_size[key] else: raise TypeError("Unrecognized 'grid_size' type") # set the mesh of points - self.mesh_grid = np.meshgrid( + self.mesh_grid = xp.meshgrid( *(self.sample_points[key] for key in self.parameter_names), indexing='ij') diff --git a/bilby/core/utils/calculus.py 
b/bilby/core/utils/calculus.py index 13e1b2586..2cc2b6ae1 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -194,7 +194,7 @@ def logtrapzexp(lnf, dx, *, xp=np): else: raise TypeError("Step size must be a single value or array-like") - return C + logsumexp([logsumexp(lnfdx1), logsumexp(lnfdx2)]) + return C + logsumexp(xp.array([logsumexp(lnfdx1), logsumexp(lnfdx2)])) class interp1d(_interp1d): diff --git a/test/core/grid_test.py b/test/core/grid_test.py index 82e44c5cc..bc17b9ce7 100644 --- a/test/core/grid_test.py +++ b/test/core/grid_test.py @@ -1,30 +1,38 @@ import unittest -import numpy as np import shutil import os -from scipy.stats import multivariate_normal + +import numpy as np +import pytest import bilby +from bilby.compat.patches import multivariate_logpdf + +import jax +from functools import partial -# set 2D multivariate Gaussian likelihood class MultiGaussian(bilby.Likelihood): - def __init__(self, mean, cov): - super(MultiGaussian, self).__init__(parameters=dict()) - self.cov = np.array(cov) - self.mean = np.array(mean) - self.sigma = np.sqrt(np.diag(self.cov)) - self.pdf = multivariate_normal(mean=self.mean, cov=self.cov) + # set 2D multivariate Gaussian likelihood + def __init__(self, mean, cov, *, xp=np): + super(MultiGaussian, self).__init__() + self.xp = xp + self.cov = xp.array(cov) + self.mean = xp.array(mean) + self.sigma = xp.sqrt(xp.diag(self.cov)) + self.logpdf = multivariate_logpdf(xp=xp, mean=self.mean, cov=self.cov) @property def dim(self): return len(self.cov[0]) - def log_likelihood(self): - x = np.array([self.parameters["x{0}".format(i)] for i in range(self.dim)]) - return self.pdf.logpdf(x) + def log_likelihood(self, parameters): + x = self.xp.array([parameters["x{0}".format(i)] for i in range(self.dim)]) + return self.logpdf(x) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGrid(unittest.TestCase): def setUp(self): bilby.core.utils.random.seed(7) @@ -33,7 +41,7 @@ def 
setUp(self): self.mus = [0.0, 0.0] self.cov = [[1.0, 0.0], [0.0, 1.0]] dim = len(self.mus) - self.likelihood = MultiGaussian(self.mus, self.cov) + self.likelihood = MultiGaussian(self.mus, self.cov, xp=self.xp) # set priors out to +/- 5 sigma self.priors = bilby.core.prior.PriorDict() @@ -61,6 +69,7 @@ def setUp(self): grid_size=self.grid_size, likelihood=self.likelihood, save=True, + xp=self.xp, ) self.grid = grid @@ -151,6 +160,7 @@ def test_fail_grid_size(self): grid_size=2.3, likelihood=self.likelihood, save=True, + xp=self.xp, ) def test_mesh_grid(self): @@ -165,7 +175,8 @@ def test_grid_integer_points(self): outdir="outdir", priors=self.priors, grid_size=n_points, - likelihood=self.likelihood + likelihood=self.likelihood, + xp=self.xp, ) self.assertTupleEqual(tuple(n_points), grid.mesh_grid[0].shape) @@ -179,7 +190,8 @@ def test_grid_dict_points(self): outdir="outdir", priors=self.priors, grid_size=n_points, - likelihood=self.likelihood + likelihood=self.likelihood, + xp=self.xp, ) self.assertTupleEqual((n_points["x0"], n_points["x1"]), grid.mesh_grid[0].shape) self.assertEqual(grid.mesh_grid[0][0, 0], self.priors[self.grid.parameter_names[0]].minimum) @@ -196,6 +208,8 @@ def test_grid_from_array(self): priors=self.priors, grid_size=n_points, likelihood=self.likelihood, + xp=self.xp, + vectorized=True, ) self.assertTupleEqual((len(x0s), len(x1s)), grid.mesh_grid[0].shape) From 9d273564d429f7c5e8387fc67744a1789fcaa024 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sun, 1 Feb 2026 11:05:55 -0500 Subject: [PATCH 075/110] TEST: add more jax test coverage --- bilby/compat/patches.py | 14 +++ bilby/compat/utils.py | 35 ++++++- bilby/core/grid.py | 1 - bilby/core/prior/analytical.py | 143 ++++++++++++++-------------- bilby/core/prior/base.py | 60 ++++++++---- bilby/core/prior/conditional.py | 31 ++++-- bilby/core/prior/dict.py | 71 ++++++++------ bilby/core/prior/interpolated.py | 8 +- bilby/core/prior/joint.py | 97 +++++++++---------- 
bilby/core/prior/slabspike.py | 22 +++-- bilby/core/utils/calculus.py | 2 +- bilby/gw/prior.py | 66 +++++++------ test/conftest.py | 3 + test/core/grid_test.py | 1 - test/core/prior/analytical_test.py | 132 ++++++++++++++++--------- test/core/prior/base_test.py | 26 ++++- test/core/prior/conditional_test.py | 24 ++++- test/core/prior/dict_test.py | 50 ++++++---- test/core/prior/prior_test.py | 120 ++++++++++++++--------- test/core/prior/slabspike_test.py | 64 +++++++------ test/core/result_test.py | 14 ++- test/gw/prior_test.py | 17 +++- 22 files changed, 623 insertions(+), 378 deletions(-) diff --git a/bilby/compat/patches.py b/bilby/compat/patches.py index 02cbc3394..7c497a24e 100644 --- a/bilby/compat/patches.py +++ b/bilby/compat/patches.py @@ -1,4 +1,5 @@ import array_api_compat as aac +import numpy as np from .utils import BackendNotImplementedError @@ -35,3 +36,16 @@ def multivariate_logpdf(xp, mean, cov): else: raise BackendNotImplementedError return logpdf + + +def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, *, xp=None): + if xp is None: + xp = a.__array_namespace__() + + if "jax" in xp.__name__: + # the scipy version of logsumexp cannot be vmapped + from jax.scipy.special import logsumexp as lse + else: + from scipy.special import logsumexp as lse + + return lse(a=a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign) diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py index 4b099969b..5e00f8969 100644 --- a/bilby/compat/utils.py +++ b/bilby/compat/utils.py @@ -43,14 +43,39 @@ def promote_to_array(args, backend, skip=None): return args -def xp_wrap(func): +def xp_wrap(func, no_xp=False): + """ + A decorator that will figure out the array module from the input + arguments and pass it to the function as the 'xp' keyword argument. - def wrapped(self, *args, **kwargs): - if "xp" not in kwargs: + Parameters + ========== + func: function + The function to be decorated. 
+ no_xp: bool + If True, the decorator will not attempt to add the 'xp' keyword + argument and so the wrapper is a no-op. + + Returns + ======= + function + The decorated function. + """ + + def wrapped(self, *args, xp=None, **kwargs): + if not no_xp and xp is None: try: - kwargs["xp"] = array_module(*args) + if len(args) > 0: + array_module = array_namespace(*args) + elif len(kwargs) > 0: + array_module = array_namespace(*kwargs.values()) + else: + array_module = np + kwargs["xp"] = array_module except TypeError: - pass + kwargs["xp"] = np + elif not no_xp: + kwargs["xp"] = xp return func(self, *args, **kwargs) return wrapped diff --git a/bilby/core/grid.py b/bilby/core/grid.py index 2c09f75bd..de2972b10 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -208,7 +208,6 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): places = self.sample_points[name] xp = log_array.__array_namespace__() - print(xp) if len(places) > 1: dx = xp.diff(places) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index ca48e3393..118e36ce8 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -45,7 +45,7 @@ def __init__(self, peak, name=None, latex_label=None, unit=None): self._is_fixed = True self.least_recently_sampled = peak - def rescale(self, val): + def rescale(self, val, *, xp=None): """Rescale everything to the peak with the correct shape. 
Parameters @@ -58,7 +58,7 @@ def rescale(self, val): """ return self.peak * val ** 0 - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -73,7 +73,7 @@ def prob(self, val): at_peak = (val == self.peak) return at_peak * 1.0 - def cdf(self, val): + def cdf(self, val, *, xp=None): return 1.0 * (val > self.peak) @@ -106,7 +106,7 @@ def __init__(self, alpha, minimum, maximum, name=None, latex_label=None, self.alpha = alpha @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the power-law prior. @@ -128,7 +128,7 @@ def rescale(self, val, *, xp=np): (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) ** (1. / (1 + self.alpha)) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -147,7 +147,7 @@ def prob(self, val, *, xp=np): self.minimum ** (1 + self.alpha))) * self.is_in_prior_range(val) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Return the logarithmic prior probability of val Parameters @@ -172,7 +172,7 @@ def ln_prob(self, val, *, xp=np): return ln_p + ln_in_range @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): if self.alpha == -1: with np.errstate(invalid="ignore"): _cdf = xp.log(val / self.minimum) / xp.log(self.maximum / self.minimum) @@ -210,7 +210,7 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary) - def rescale(self, val): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the power-law prior. 
@@ -227,7 +227,7 @@ def rescale(self, val): """ return self.minimum + val * (self.maximum - self.minimum) - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -241,7 +241,7 @@ def prob(self, val): return ((val >= self.minimum) & (val <= self.maximum)) / (self.maximum - self.minimum) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Return the log prior probability of val Parameters @@ -256,7 +256,7 @@ def ln_prob(self, val, *, xp=np): return xp.log(self.prob(val)) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): _cdf = (val - self.minimum) / (self.maximum - self.minimum) return xp.clip(_cdf, 0, 1) @@ -319,7 +319,7 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, boundary=boundary) @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the power-law prior. @@ -337,7 +337,7 @@ def rescale(self, val, *, xp=np): return xp.sign(2 * val - 1) * self.minimum * xp.exp(abs(2 * val - 1) * xp.log(self.maximum / self.minimum)) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val Parameters @@ -353,7 +353,7 @@ def prob(self, val, *, xp=np): self.is_in_prior_range(val)) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Return the logarithmic prior probability of val Parameters @@ -368,7 +368,7 @@ def ln_prob(self, val, *, xp=np): return xp.nan_to_num(- xp.log(2 * xp.abs(val)) - xp.log(xp.log(self.maximum / self.minimum))) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): asymmetric = xp.log(abs(val) / self.minimum) / xp.log(self.maximum / self.minimum) return xp.clip(0.5 * (1 + xp.sign(val) * asymmetric), 0, 1) @@ -398,7 +398,7 @@ def __init__(self, minimum=-np.pi / 2, maximum=np.pi / 2, name=None, latex_label=latex_label, unit=unit, 
boundary=boundary) @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to a uniform in cosine prior. @@ -408,7 +408,7 @@ def rescale(self, val, *, xp=np): return xp.arcsin(val / norm + xp.sin(self.minimum)) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Defined over [-pi/2, pi/2]. Parameters @@ -422,7 +422,7 @@ def prob(self, val, *, xp=np): return xp.cos(val) / 2 * self.is_in_prior_range(val) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): _cdf = ( (xp.sin(val) - xp.sin(self.minimum)) / (xp.sin(self.maximum) - xp.sin(self.minimum)) @@ -458,7 +458,7 @@ def __init__(self, minimum=0, maximum=np.pi, name=None, latex_label=latex_label, unit=unit, boundary=boundary) @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to a uniform in sine prior. @@ -468,7 +468,7 @@ def rescale(self, val, *, xp=np): return xp.arccos(xp.cos(self.minimum) - val / norm) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Defined over [0, pi]. Parameters @@ -482,7 +482,7 @@ def prob(self, val, *, xp=np): return xp.sin(val) / 2 * self.is_in_prior_range(val) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): _cdf = ( (xp.cos(val) - xp.cos(self.minimum)) / (xp.cos(self.maximum) - xp.cos(self.minimum)) @@ -518,7 +518,7 @@ def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=N self.sigma = sigma @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Gaussian prior. 
@@ -535,7 +535,7 @@ def rescale(self, val, *, xp=np): return self.mu + erfinv(2 * val - 1) * 2 ** 0.5 * self.sigma @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -549,7 +549,7 @@ def prob(self, val, *, xp=np): return xp.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 / self.sigma @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Return the Log prior probability of val. Parameters @@ -562,7 +562,7 @@ def ln_prob(self, val, *, xp=np): """ return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + xp.log(2 * np.pi * self.sigma ** 2)) - def cdf(self, val): + def cdf(self, val, *, xp=None): return (1 - erf((self.mu - val) / 2 ** 0.5 / self.sigma)) / 2 @@ -614,7 +614,7 @@ def normalisation(self): (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate truncated Gaussian prior. @@ -630,7 +630,7 @@ def rescale(self, val, *, xp=np): (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) * 2 ** 0.5 * self.sigma + self.mu @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. 
Parameters @@ -644,7 +644,7 @@ def prob(self, val, *, xp=np): return xp.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 \ / self.sigma / self.normalisation * self.is_in_prior_range(val) - def cdf(self, val): + def cdf(self, val, *, xp=None): _cdf = (erf((val - self.mu) / 2 ** 0.5 / self.sigma) - erf( (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 / self.normalisation _cdf *= val >= self.minimum @@ -714,7 +714,7 @@ def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=N self.sigma = sigma @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate LogNormal prior. @@ -727,7 +727,7 @@ def rescale(self, val, *, xp=np): return xp.exp(self.mu + (2 * self.sigma ** 2)**0.5 * erfinv(2 * val - 1)) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Returns the prior probability of val. Parameters @@ -741,7 +741,7 @@ def prob(self, val, *, xp=np): return xp.exp(self.ln_prob(val, xp=xp)) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -759,7 +759,7 @@ def ln_prob(self, val, *, xp=np): ) + xp.log(val > self.minimum), nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): with np.errstate(divide="ignore"): return 0.5 + erf( (xp.log(xp.maximum(val, self.minimum)) - self.mu) / self.sigma / np.sqrt(2) @@ -792,7 +792,7 @@ def __init__(self, mu, name=None, latex_label=None, unit=None, boundary=None): self.mu = mu @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Exponential prior. 
@@ -802,7 +802,7 @@ def rescale(self, val, *, xp=np): return -self.mu * xp.log1p(-val) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -816,7 +816,7 @@ def prob(self, val, *, xp=np): return xp.exp(self.ln_prob(val, xp=xp)) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -831,7 +831,7 @@ def ln_prob(self, val, *, xp=np): return -val / self.mu - xp.log(self.mu) + xp.log(val >= self.minimum) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): with np.errstate(divide="ignore", invalid="ignore", over="ignore"): return xp.maximum(1. - xp.exp(-val / self.mu), 0) @@ -871,7 +871,7 @@ def __init__(self, df, mu=0., scale=1., name=None, latex_label=None, self.scale = scale @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Student's t-prior. @@ -889,7 +889,7 @@ def rescale(self, val, *, xp=np): ) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -903,7 +903,7 @@ def prob(self, val, *, xp=np): return xp.exp(self.ln_prob(val)) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -920,7 +920,7 @@ def ln_prob(self, val, *, xp=np): * xp.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) ) - def cdf(self, val): + def cdf(self, val, *, xp=None): return stdtr(self.df, (val - self.mu) / self.scale) @@ -963,7 +963,7 @@ def __init__(self, alpha, beta, minimum=0, maximum=1, name=None, self.beta = beta @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Beta prior. 
@@ -979,7 +979,7 @@ def rescale(self, val, *, xp=np): ) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -993,7 +993,7 @@ def prob(self, val, *, xp=np): return xp.exp(self.ln_prob(val, xp=xp)) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -1009,7 +1009,7 @@ def ln_prob(self, val, *, xp=np): return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): return xp.nan_to_num( betainc(self.alpha, self.beta, (val - self.minimum) / (self.maximum - self.minimum)) ) + (val > self.maximum) @@ -1045,7 +1045,7 @@ def __init__(self, mu, scale, name=None, latex_label=None, unit=None, boundary=N self.scale = scale @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Logistic prior. @@ -1056,7 +1056,7 @@ def rescale(self, val, *, xp=np): return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), 0)) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1070,7 +1070,7 @@ def prob(self, val, *, xp=np): return xp.exp(self.ln_prob(val)) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. Parameters @@ -1086,7 +1086,7 @@ def ln_prob(self, val, *, xp=np): 2. * xp.log1p(xp.exp(-(val - self.mu) / self.scale)) - xp.log(self.scale) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): return 1. / (1. 
+ xp.exp(-(val - self.mu) / self.scale)) @@ -1120,7 +1120,7 @@ def __init__(self, alpha, beta, name=None, latex_label=None, unit=None, boundary self.beta = beta @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Cauchy prior. @@ -1130,7 +1130,7 @@ def rescale(self, val, *, xp=np): with np.errstate(divide="ignore", invalid="ignore"): return rescaled - xp.log(val < 1) + xp.log(val > 0) - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1144,7 +1144,7 @@ def prob(self, val): return 1. / self.beta / np.pi / (1. + ((val - self.alpha) / self.beta) ** 2) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Return the log prior probability of val. Parameters @@ -1158,7 +1158,7 @@ def ln_prob(self, val, *, xp=np): return - xp.log(self.beta * np.pi) - xp.log(1. + ((val - self.alpha) / self.beta) ** 2) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): return 0.5 + xp.arctan((val - self.alpha) / self.beta) / np.pi @@ -1197,7 +1197,7 @@ def __init__(self, k, theta=1., name=None, latex_label=None, unit=None, boundary self.theta = theta @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Gamma prior. @@ -1206,7 +1206,7 @@ def rescale(self, val, *, xp=np): return xp.asarray(gammaincinv(self.k, val)) * self.theta @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1220,7 +1220,7 @@ def prob(self, val, *, xp=np): return xp.exp(self.ln_prob(val)) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Returns the log prior probability of val. 
Parameters @@ -1239,7 +1239,7 @@ def ln_prob(self, val, *, xp=np): return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=xp.inf) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): return gammainc(xp.asarray(self.k), xp.maximum(val, self.minimum) / self.theta) @@ -1331,7 +1331,7 @@ def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, self.expr = xp.exp(self.r) @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the appropriate Fermi-Dirac prior. @@ -1352,7 +1352,7 @@ def rescale(self, val, *, xp=np): return -self.sigma * xp.log(xp.maximum(inv, 0)) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1370,7 +1370,7 @@ def prob(self, val, *, xp=np): ) @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Return the log prior probability of val. Parameters @@ -1384,7 +1384,7 @@ def ln_prob(self, val, *, xp=np): return xp.log(self.prob(val)) @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): """ Evaluate the CDF of the Fermi-Dirac distribution using a slightly modified form of Equation 23 of [1]_. @@ -1482,7 +1482,7 @@ def __init__( self._cumulative_weights_array = xp.insert(_cumulative_weights_array, 0, 0) @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the discrete-value prior. @@ -1501,7 +1501,7 @@ def rescale(self, val, *, xp=np): return xp.asarray(self._values_array)[index] @xp_wrap - def cdf(self, val, *, xp=np): + def cdf(self, val, *, xp=None): """Return the cumulative prior probability of val. 
Parameters @@ -1516,7 +1516,7 @@ def cdf(self, val, *, xp=np): return xp.asarray(self._cumulative_weights_array)[index] @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Parameters @@ -1533,7 +1533,8 @@ def prob(self, val, *, xp=np): # turn 0d numpy array to scalar return p[()] - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, xp=None): """Return the logarithmic prior probability of val Parameters @@ -1545,12 +1546,12 @@ def ln_prob(self, val): float: """ - index = np.searchsorted(self._values_array, val) - index = np.clip(index, 0, self.nvalues - 1) - lnp = np.where( + index = xp.searchsorted(self._values_array, val) + index = xp.clip(index, 0, self.nvalues - 1) + lnp = xp.where( self._values_array[index] == val, self._lnweights_array[index], -np.inf ) - # turn 0d numpy array to scalar + # turn 0d array to scalar return lnp[()] @@ -1674,7 +1675,7 @@ def __init__(self, mode, minimum, maximum, name=None, latex_label=None, unit=Non self.rescaled_minimum = self.minimum - (self.minimum == self.mode) * self.scale self.rescaled_maximum = self.maximum + (self.maximum == self.mode) * self.scale - def rescale(self, val): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from standard uniform to a triangular distribution. 
@@ -1696,7 +1697,7 @@ def rescale(self, val): self.maximum - above_mode ) * (val >= self.fractional_mode) - def prob(self, val): + def prob(self, val, *, xp=None): """ Return the prior probability of val @@ -1723,7 +1724,7 @@ def prob(self, val): ) return 2.0 * (between_minimum_and_mode + between_mode_and_maximum) / self.scale - def cdf(self, val): + def cdf(self, val, *, xp=None): """ Return the prior cumulative probability at val diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index 5ca28de28..ea6ef5475 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -2,7 +2,9 @@ import json import os import re +import warnings +import array_api_compat as aac import numpy as np import scipy.stats @@ -14,6 +16,7 @@ get_dict_with_properties, WrappedInterp1d as interp1d, ) +from ...compat.utils import xp_wrap class Prior(object): @@ -57,6 +60,27 @@ def __init__(self, name=None, latex_label=None, unit=None, minimum=-np.inf, self.boundary = boundary self._is_fixed = False + def __init_subclass__(cls): + for method_name in ["prob", "ln_prob", "rescale", "cdf", "sample"]: + method = getattr(cls, method_name, None) + if method is not None: + from inspect import signature + + sig = signature(method) + if "xp" not in sig.parameters: + warnings.warn( + f"The method {method_name} of the prior class " + f"{cls.__name__} does not accept an 'xp' keyword " + "argument. This may cause some behaviour to fail. " + "Please see the bilby documentation for more " + "information: https://bilby-dev.github.io/bilby/" + "array_api.html" + f" {sig}", + DeprecationWarning, + stacklevel=2, + ) + setattr(cls, method_name, xp_wrap(method, no_xp=True)) + def __call__(self): """Overrides the __call__ special method. Calls the sample method. 
@@ -106,7 +130,7 @@ def __eq__(self, other): for key in this_dict: if key == "least_recently_sampled": continue - if isinstance(this_dict[key], np.ndarray): + if aac.is_array_api_obj(this_dict[key]): if not np.array_equal(this_dict[key], other_dict[key]): return False elif isinstance(this_dict[key], type(scipy.stats.beta(1., 1.))): @@ -116,7 +140,7 @@ def __eq__(self, other): return False return True - def sample(self, size=None): + def sample(self, size=None, *, xp=np): """Draw a sample from the prior Parameters @@ -131,10 +155,12 @@ def sample(self, size=None): """ from ..utils import random - self.least_recently_sampled = self.rescale(random.rng.uniform(0, 1, size)) + self.least_recently_sampled = self.rescale( + xp.array(random.rng.uniform(0, 1, size)) + ) return self.least_recently_sampled - def rescale(self, val): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the prior. @@ -152,7 +178,7 @@ def rescale(self, val): """ return None - def prob(self, val): + def prob(self, val, *, xp=None): """Return the prior probability of val, this should be overwritten Parameters @@ -166,24 +192,22 @@ def prob(self, val): """ return np.nan - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): """ Generic method to calculate CDF, can be overwritten in subclass """ from scipy.integrate import cumulative_trapezoid if np.any(np.isinf([self.minimum, self.maximum])): raise ValueError( "Unable to use the generic CDF calculation for priors with" "infinite support") - x = np.linspace(self.minimum, self.maximum, 1000) - pdf = self.prob(x) + x = xp.linspace(self.minimum, self.maximum, 1000) + pdf = self.prob(x, xp=xp) cdf = cumulative_trapezoid(pdf, x, initial=0) - interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False, - fill_value=(0, 1)) - output = interp(val) - if isinstance(val, (int, float)): - output = float(output) - return output - - def ln_prob(self, val): + output = xp.interp(val, x, cdf / cdf[-1], left=0, 
right=1) + return output[()] + + @xp_wrap + def ln_prob(self, val, *, xp=None): """Return the prior ln probability of val, this should be overwritten Parameters @@ -196,7 +220,7 @@ def ln_prob(self, val): """ with np.errstate(divide='ignore'): - return np.log(self.prob(val)) + return xp.log(self.prob(val, xp=xp)) def is_in_prior_range(self, val): """Returns True if val is in the prior boundaries, zero otherwise @@ -473,7 +497,7 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, latex_label=latex_label, unit=unit) self._is_fixed = True - def prob(self, val): + def prob(self, val, *, xp=None): return (val > self.minimum) & (val < self.maximum) diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index d0c7191a4..c221939b8 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -1,9 +1,12 @@ +import numpy as np + from .base import Prior, PriorException from .interpolated import Interped from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \ LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac from ..utils import infer_args_from_method, infer_parameters_from_function +from ...compat.utils import xp_wrap def conditional_prior_factory(prior_class): @@ -59,7 +62,7 @@ def condition_func(reference_params, y): self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__) self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__) - def sample(self, size=None, **required_variables): + def sample(self, size=None, *, xp=np, **required_variables): """Draw a sample from the prior Parameters @@ -76,10 +79,15 @@ def sample(self, size=None, **required_variables): """ from ..utils import random - self.least_recently_sampled = self.rescale(random.rng.uniform(0, 1, size), **required_variables) + self.least_recently_sampled = self.rescale( + 
xp.array(random.rng.uniform(0, 1, size)), + xp=xp, + **required_variables, + ) return self.least_recently_sampled - def rescale(self, val, **required_variables): + @xp_wrap + def rescale(self, val, *, xp=None, **required_variables): """ 'Rescale' a sample from the unit line element to the prior. @@ -93,9 +101,10 @@ def rescale(self, val, **required_variables): """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).rescale(val) + return super(ConditionalPrior, self).rescale(val, xp=xp) - def prob(self, val, **required_variables): + @xp_wrap + def prob(self, val, *, xp=None, **required_variables): """Return the prior probability of val. Parameters @@ -111,9 +120,10 @@ def prob(self, val, **required_variables): float: Prior probability of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).prob(val) + return super(ConditionalPrior, self).prob(val, xp=xp) - def ln_prob(self, val, **required_variables): + @xp_wrap + def ln_prob(self, val, *, xp=None, **required_variables): """Return the natural log prior probability of val. Parameters @@ -129,9 +139,10 @@ def ln_prob(self, val, **required_variables): float: Natural log prior probability of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).ln_prob(val) + return super(ConditionalPrior, self).ln_prob(val, xp=xp) - def cdf(self, val, **required_variables): + @xp_wrap + def cdf(self, val, *, xp=None, **required_variables): """Return the cdf of val. 
Parameters @@ -147,7 +158,7 @@ def cdf(self, val, **required_variables): float: CDF of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).cdf(val) + return super(ConditionalPrior, self).cdf(val, xp=xp) def update_conditions(self, **required_variables): """ diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index bef81b1ef..cad562d16 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -16,7 +16,7 @@ BilbyJsonEncoder, decode_bilby_json, ) -from ...compat.utils import array_module +from ...compat.utils import array_module, xp_wrap class PriorDict(dict): @@ -353,7 +353,7 @@ def fill_priors(self, likelihood=None, default_priors_file=None): for key in self: self.test_redundancy(key) - def sample(self, size=None): + def sample(self, size=None, *, xp=np): """Draw samples from the prior set Parameters @@ -365,9 +365,9 @@ def sample(self, size=None): ======= dict: Dictionary of the samples """ - return self.sample_subset_constrained(keys=list(self.keys()), size=size) + return self.sample_subset_constrained(keys=list(self.keys()), size=size, xp=xp) - def sample_subset_constrained_as_array(self, keys=iter([]), size=None): + def sample_subset_constrained_as_array(self, keys=iter([]), size=None, *, xp=np): """Return an array of samples Parameters @@ -382,12 +382,12 @@ def sample_subset_constrained_as_array(self, keys=iter([]), size=None): array: array_like An array of shape (len(key), size) of the samples (ordered by keys) """ - samples_dict = self.sample_subset_constrained(keys=keys, size=size) - samples_dict = {key: np.atleast_1d(val) for key, val in samples_dict.items()} + samples_dict = self.sample_subset_constrained(keys=keys, size=size, xp=xp) + samples_dict = {key: xp.atleast_1d(val) for key, val in samples_dict.items()} samples_list = [samples_dict[key] for key in keys] - return np.array(samples_list) + return xp.array(samples_list) - def sample_subset(self, keys=iter([]), size=None): + def 
sample_subset(self, keys=iter([]), size=None, *, xp=np): """Draw samples from the prior set for parameters which are not a DeltaFunction Parameters @@ -407,7 +407,7 @@ def sample_subset(self, keys=iter([]), size=None): if isinstance(self[key], Constraint): continue elif isinstance(self[key], Prior): - samples[key] = self[key].sample(size=size) + samples[key] = self[key].sample(size=size, xp=xp) else: logger.debug("{} not a known prior.".format(key)) return samples @@ -430,7 +430,7 @@ def fixed_keys(self): def constraint_keys(self): return [k for k, p in self.items() if isinstance(p, Constraint)] - def sample_subset_constrained(self, keys=iter([]), size=None): + def sample_subset_constrained(self, keys=iter([]), size=None, *, xp=np): """ Sample a subset of priors while ensuring constraints are satisfied. @@ -446,7 +446,7 @@ def sample_subset_constrained(self, keys=iter([]), size=None): dict: Dictionary of valid samples. """ if not any(isinstance(self[key], Constraint) for key in self): - return self.sample_subset(keys=keys, size=size) + return self.sample_subset(keys=keys, size=size, xp=xp) efficiency_warning_was_issued = False @@ -462,7 +462,7 @@ def check_efficiency(n_tested, n_valid): n_tested_samples, n_valid_samples = 0, 0 if size is None or size == 1: while True: - sample = self.sample_subset(keys=keys, size=size) + sample = self.sample_subset(keys=keys, size=size, xp=xp) is_valid = self.evaluate_constraints(sample) n_tested_samples += 1 n_valid_samples += int(is_valid) @@ -477,17 +477,17 @@ def check_efficiency(n_tested, n_valid): all_samples = {key: np.array([]) for key in keys} _first_key = list(all_samples.keys())[0] while len(all_samples[_first_key]) < needed: - samples = self.sample_subset(keys=keys, size=needed) + samples = self.sample_subset(keys=keys, size=needed, xp=xp) keep = np.array(self.evaluate_constraints(samples), dtype=bool) for key in keys: - all_samples[key] = np.hstack( + all_samples[key] = xp.hstack( [all_samples[key], 
samples[key][keep].flatten()] ) n_tested_samples += needed n_valid_samples += np.sum(keep) check_efficiency(n_tested_samples, n_valid_samples) all_samples = { - key: np.reshape(all_samples[key][:needed], size) for key in keys + key: xp.reshape(all_samples[key][:needed], size) for key in keys } return all_samples @@ -512,15 +512,15 @@ def normalize_constraint_factor( self._cached_normalizations[keys] = factor_rounded return factor_rounded - def _estimate_normalization(self, keys, min_accept, sampling_chunk): - samples = self.sample_subset(keys=keys, size=sampling_chunk) + def _estimate_normalization(self, keys, min_accept, sampling_chunk, *, xp=np): + samples = self.sample_subset(keys=keys, size=sampling_chunk, xp=xp) keep = np.atleast_1d(self.evaluate_constraints(samples)) if len(keep) == 1: self._cached_normalizations[keys] = 1 return 1 all_samples = {key: np.array([]) for key in keys} while np.count_nonzero(keep) < min_accept: - samples = self.sample_subset(keys=keys, size=sampling_chunk) + samples = self.sample_subset(keys=keys, size=sampling_chunk, xp=xp) for key in samples: all_samples[key] = np.hstack([all_samples[key], samples[key].flatten()]) keep = np.array(self.evaluate_constraints(all_samples), dtype=bool) @@ -610,7 +610,8 @@ def check_ln_prob(self, sample, ln_prob, normalized=True): constrained_ln_prob[in_bounds] = ln_prob[in_bounds] + keep + np.log(ratio) return constrained_ln_prob - def cdf(self, sample): + @xp_wrap + def cdf(self, sample, *, xp=None): """Evaluate the cumulative distribution function at the provided points Parameters @@ -624,10 +625,10 @@ def cdf(self, sample): """ return sample.__class__( - {key: self[key].cdf(sample) for key, sample in sample.items()} + {key: self[key].cdf(sample, xp=xp) for key, sample in sample.items()} ) - def rescale(self, keys, theta): + def rescale(self, keys, theta, *, xp=None): """Rescale samples from unit cube to prior Parameters @@ -643,9 +644,10 @@ def rescale(self, keys, theta): """ if isinstance(theta, 
{}.values().__class__): theta = list(theta) - xp = array_module(theta) + if xp is None: + xp = array_module(theta) - return xp.asarray([self[key].rescale(sample) for key, sample in zip(keys, theta)]) + return xp.asarray([self[key].rescale(sample, xp=xp) for key, sample in zip(keys, theta)]) def test_redundancy(self, key, disable_logging=False): """Empty redundancy test, should be overwritten in subclasses""" @@ -745,7 +747,7 @@ def _check_conditions_resolved(self, key, sampled_keys): conditions_resolved = False return conditions_resolved - def sample_subset(self, keys=iter([]), size=None): + def sample_subset(self, keys=iter([]), size=None, *, xp=np): self.convert_floats_to_delta_functions() add_delta_keys = [ key @@ -765,7 +767,9 @@ def sample_subset(self, keys=iter([]), size=None): if isinstance(self[key], Prior): try: samples[key] = subset_dict[key].sample( - size=size, **subset_dict.get_required_variables(key) + size=size, + xp=xp, + **subset_dict.get_required_variables(key), ) except ValueError: # Some prior classes can not handle an array of conditional parameters (e.g. 
alpha for PowerLaw) @@ -776,7 +780,10 @@ def sample_subset(self, keys=iter([]), size=None): rvars = { key: value[i] for key, value in required_variables.items() } - samples[key][i] = subset_dict[key].sample(**rvars) + samples[key][i] = subset_dict[key].sample( + **rvars, + xp=xp, + ) else: logger.debug("{} not a known prior.".format(key)) return samples @@ -798,7 +805,8 @@ def get_required_variables(self, key): for k in getattr(self[key], "required_variables", []) } - def prob(self, sample, **kwargs): + @xp_wrap + def prob(self, sample, *, xp=None, **kwargs): """ Parameters @@ -814,9 +822,8 @@ def prob(self, sample, *, xp=None, **kwargs): """ self._prepare_evaluation(*zip(*sample.items())) - xp = array_module(sample.values()) res = xp.asarray([ - self[key].prob(sample[key], **self.get_required_variables(key)) + self[key].prob(sample[key], **self.get_required_variables(key), xp=xp) for key in sample ]) prob = xp.prod(res, **kwargs) @@ -852,10 +859,11 @@ def ln_prob(self, sample, axis=None, normalized=True): # return self.check_ln_prob(sample, ln_prob, # normalized=normalized) - def cdf(self, sample): + @xp_wrap + def cdf(self, sample, *, xp=None): self._prepare_evaluation(*zip(*sample.items())) res = { - key: self[key].cdf(sample[key], **self.get_required_variables(key)) + key: self[key].cdf(sample[key], **self.get_required_variables(key), xp=xp) for key in sample } return sample.__class__(res) diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index ab03809d1..1983877d7 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -66,7 +66,8 @@ def __eq__(self, other): return True return False - def prob(self, val): + @xp_wrap + def prob(self, val, *, xp=None): """Return the prior probability of val. 
Parameters @@ -79,11 +80,12 @@ def prob(self, val): """ return self.probability_density(val)[()] - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): return self.cumulative_distribution(val)[()] @xp_wrap - def rescale(self, val, *, xp=np): + def rescale(self, val, *, xp=None): """ 'Rescale' a sample from the unit line element to the prior. diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 06a740497..0e8e8abfa 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -1,5 +1,6 @@ import re +import array_api_extra as xpx import numpy as np import scipy.stats from scipy.special import erfinv @@ -173,13 +174,14 @@ def _split_repr(cls, string): kwargs[key.strip()] = arg return kwargs - def prob(self, samp): + @xp_wrap + def prob(self, samp, *, xp=None): """ Get the probability of a sample. For bounded priors the probability will not be properly normalised. """ - return np.exp(self.ln_prob(samp)) + return xp.exp(self.ln_prob(samp, xp=xp)) def _check_samp(self, value): """ @@ -217,7 +219,8 @@ def _check_samp(self, value): break return samp, outbounds - def ln_prob(self, value): + @xp_wrap + def ln_prob(self, value, *, xp=None): """ Get the log-probability of a sample. For bounded priors the probability will not be properly normalised. @@ -231,14 +234,12 @@ def ln_prob(self, value): """ samp, outbounds = self._check_samp(value) - lnprob = -np.inf * np.ones(samp.shape[0]) - lnprob = self._ln_prob(samp, lnprob, outbounds) - if samp.shape[0] == 1: - return lnprob[0] - else: - return lnprob + lnprob = -np.inf * xp.ones(samp.shape[0]) + lnprob = self._ln_prob(samp, lnprob, outbounds, xp=xp) + return lnprob[()] - def _ln_prob(self, samp, lnprob, outbounds): + @xp_wrap + def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): """ Get the log-probability of a sample. For bounded priors the probability will not be properly normalised. 
**this method needs overwritten by child class** @@ -262,7 +263,7 @@ def _ln_prob(self, samp, lnprob, outbounds): """ return lnprob - def sample(self, size=1, **kwargs): + def sample(self, size=1, *, xp=np, **kwargs): """ Draw, and set, a sample from the Dist, accompanying method _sample needs to overwritten @@ -274,14 +275,11 @@ def sample(self, size=1, **kwargs): if size is None: size = 1 - samps = self._sample(size=size, **kwargs) + samps = self._sample(size=size, xp=xp, **kwargs) for i, name in enumerate(self.names): - if size == 1: - self.current_sample[name] = samps[:, i].flatten()[0] - else: - self.current_sample[name] = samps[:, i].flatten() + self.current_sample[name] = samps[:, i].flatten()[()] - def _sample(self, size, **kwargs): + def _sample(self, size, *, xp=np, **kwargs): """ Draw, and set, a sample from the joint dist (**needs to be ovewritten by child class**) @@ -290,14 +288,14 @@ def _sample(self, size, **kwargs): size: int number of samples to generate, defaults to 1 """ - samps = np.zeros((size, len(self))) + samps = xp.zeros((size, len(self))) """ Here is where the subclass where overwrite sampling method """ return samps @xp_wrap - def rescale(self, value, *, xp=np, **kwargs): + def rescale(self, value, *, xp=None, **kwargs): """ Rescale from a unit hypercube to JointPriorDist. Note that no bounds are applied in the rescale function. 
(child classes need to @@ -614,8 +612,7 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): ) @xp_wrap - def _rescale(self, samp, *, xp=np, **kwargs): - print(samp, xp) + def _rescale(self, samp, *, xp=None, **kwargs): try: mode = kwargs["mode"] except KeyError: @@ -630,12 +627,12 @@ def _rescale(self, samp, *, xp=np, **kwargs): samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 # rotate and scale to the multivariate normal shape - samp = self.mus[mode] + self.sigmas[mode] * xp.einsum( + samp = xp.array(self.mus[mode]) + self.sigmas[mode] * xp.einsum( "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode] ) return samp - def _sample(self, size, **kwargs): + def _sample(self, size, *, xp=np, **kwargs): try: mode = kwargs["mode"] except KeyError: @@ -677,18 +674,21 @@ def _sample(self, size, **kwargs): if not outbound: inbound = True - return samps + return xp.array(samps) - def _ln_prob(self, samp, lnprob, outbounds): + @xp_wrap + def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): for j in range(samp.shape[0]): # loop over the modes and sum the probabilities for i in range(self.nmodes): # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() z = (samp[j] - self.mus[i]) / self.sigmas[i] - lnprob[j] = np.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - self.logprodsigmas[i]) + lnprob = xpx.at(lnprob, j).set( + xp.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - self.logprodsigmas[i]) + ) # set out-of-bounds values to -inf - lnprob[outbounds] = -np.inf + lnprob = xp.where(outbounds, -xp.inf, lnprob) return lnprob def __eq__(self, other): @@ -783,7 +783,7 @@ def maximum(self, maximum): self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum) @xp_wrap - def rescale(self, val, *, xp=np, **kwargs): + def rescale(self, val, *, xp=None, **kwargs): """ Scale a unit hypercube sample to the prior. 
@@ -808,7 +808,7 @@ def rescale(self, val, *, xp=np, **kwargs): else: return [] # return empty list - def sample(self, size=1, **kwargs): + def sample(self, size=1, *, xp=np, **kwargs): """ Draw a sample from the prior. @@ -833,7 +833,7 @@ def sample(self, size=1, **kwargs): if len(self.dist.current_sample) == 0: # generate a sample - self.dist.sample(size=size, **kwargs) + self.dist.sample(size=size, xp=xp, **kwargs) sample = self.dist.current_sample[self.name] @@ -846,7 +846,8 @@ def sample(self, size=1, **kwargs): self.least_recently_sampled = sample return sample - def ln_prob(self, val): + @xp_wrap + def ln_prob(self, val, *, xp=None): """ Return the natural logarithm of the prior probability. Note that this will not be correctly normalised if there are bounds on the @@ -868,25 +869,16 @@ def ln_prob(self, val): values = list(self.dist.requested_parameters.values()) # check for the same number of values for each parameter - for i in range(len(self.dist) - 1): - if isinstance(values[i], (list, np.ndarray)) or isinstance( - values[i + 1], (list, np.ndarray) - ): - if isinstance(values[i], (list, np.ndarray)) and isinstance( - values[i + 1], (list, np.ndarray) - ): - if len(values[i]) != len(values[i + 1]): - raise ValueError( - "Each parameter must have the same " - "number of requested values." - ) - else: - raise ValueError( - "Each parameter must have the same " - "number of requested values." - ) + shapes = set() + for v in values: + shapes.add(xp.array(v).shape) + if len(shapes) > 1: + raise ValueError( + "Each parameter must have the same " + "number of requested values." 
+            )
-        lnp = self.dist.ln_prob(np.asarray(values).T)
+        lnp = self.dist.ln_prob(xp.array(values).T)
 
         # reset the requested parameters
         self.dist.reset_request()
@@ -905,9 +897,10 @@ def ln_prob(self, val):
             if len(val) == 1:
                 return 0.0
             else:
-                return np.zeros_like(val)
+                return xp.zeros_like(val)
 
-    def prob(self, val):
+    @xp_wrap
+    def prob(self, val, *, xp=None):
         """Return the prior probability of val
 
         Parameters
@@ -921,7 +914,7 @@ def prob(self, val):
             the p value for the prior at given sample
         """
 
-        return np.exp(self.ln_prob(val))
+        return xp.exp(self.ln_prob(val, xp=xp))
 
 
 class MultivariateGaussian(JointPrior):
diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py
index ff823a369..30ca16639 100644
--- a/bilby/core/prior/slabspike.py
+++ b/bilby/core/prior/slabspike.py
@@ -73,7 +73,7 @@ def _find_inverse_cdf_fraction_before_spike(self):
         return float(self.slab.cdf(self.spike_location)) * self.slab_fraction
 
     @xp_wrap
-    def rescale(self, val, *, xp=np):
+    def rescale(self, val, *, xp=None):
         """
         'Rescale' a sample from the unit line element to the prior.
 
@@ -93,15 +93,18 @@ def rescale(self, val, *, xp=np):
         )
 
         higher_indices = val >= (self.inverse_cdf_below_spike + self.spike_height)
-        slab_scaled = self._contracted_rescale(val - self.spike_height * higher_indices)
+        slab_scaled = self._contracted_rescale(
+            val - self.spike_height * higher_indices, xp=xp
+        )
         res = xp.select(
             [lower_indices | higher_indices, intermediate_indices],
             [slab_scaled, self.spike_location],
         )
         return res
 
-    def _contracted_rescale(self, val):
+    @xp_wrap
+    def _contracted_rescale(self, val, *, xp=None):
         """
         Contracted version of the rescale function that implements the
         `rescale` function on the pure slab part of the prior.
 
@@ -115,10 +118,10 @@ def _contracted_rescale(self, val):
         =======
         array_like:
             Associated prior value with input value.
""" - return self.slab.rescale(val / self.slab_fraction) + return self.slab.rescale(val / self.slab_fraction, xp=xp) @xp_wrap - def prob(self, val, *, xp=np): + def prob(self, val, *, xp=None): """Return the prior probability of val. Returns np.inf for the spike location @@ -136,7 +141,7 @@ def prob(self, val, *, xp=np): return res @xp_wrap - def ln_prob(self, val, *, xp=np): + def ln_prob(self, val, *, xp=None): """Return the Log prior probability of val. Returns np.inf for the spike location @@ -153,7 +158,8 @@ def ln_prob(self, val, *, xp=np): res += xp.nan_to_num(xp.inf * (val == self.spike_location), posinf=xp.inf) return res - def cdf(self, val): + @xp_wrap + def cdf(self, val, *, xp=None): """ Return the CDF of the prior. This calls to the slab CDF and adds a discrete step at the spike location. @@ -167,6 +173,6 @@ def cdf(self, val): array_like: CDF value of val """ - res = self.slab.cdf(val) * self.slab_fraction + res = self.slab.cdf(val, xp=xp) * self.slab_fraction res += (val > self.spike_location) * self.spike_height return res diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index 2cc2b6ae1..889e086b0 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -3,9 +3,9 @@ import array_api_compat as aac import numpy as np from scipy.interpolate import RectBivariateSpline, interp1d as _interp1d -from scipy.special import logsumexp from .log import logger +from ...compat.patches import logsumexp from ...compat.utils import array_module, xp_wrap diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 3258d37d1..3fe9ed242 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -1,6 +1,7 @@ import os import copy +import array_api_extra as xpx import numpy as np from scipy.integrate import cumulative_trapezoid, trapezoid, quad from scipy.interpolate import InterpolatedUnivariateSpline @@ -431,23 +432,24 @@ def __init__(self, minimum, maximum, name='mass_ratio', latex_label='$q$', def _integral(q): return -5. 
* q**(-1. / 5.) * hyp2f1(-2. / 5., -1. / 5., 4. / 5., -q) - def cdf(self, val): + def cdf(self, val, *, xp=np): return (self._integral(val) - self._integral(self.minimum)) / self.norm - def rescale(self, val): + @xp_wrap + def rescale(self, val, *, xp=None): if self.equal_mass: - val = 2 * np.minimum(val, 1 - val) + val = 2 * xp.minimum(val, 1 - val) return self.icdf(val) - def prob(self, val): + def prob(self, val, *, xp=np): in_prior = (val >= self.minimum) & (val <= self.maximum) with np.errstate(invalid="ignore"): prob = (1. + val)**(2. / 5.) / (val**(6. / 5.)) / self.norm * in_prior return prob - def ln_prob(self, val): + def ln_prob(self, val, *, xp=np): with np.errstate(divide="ignore"): - return np.log(self.prob(val)) + return np.log(self.prob(val, xp=xp)) class AlignedSpin(Interped): @@ -512,7 +514,7 @@ def integrand(aa, chi): after performing the integral over spin orientation using a delta function identity. """ - return a_prior.prob(aa, xp=np) * z_prior.prob(chi / aa, xp=np) / aa + return a_prior.prob(aa, xp=None) * z_prior.prob(chi / aa, xp=None) / aa self.num_interp = 10_000 if num_interp is None else num_interp xx = np.linspace(chi_min, chi_max, self.num_interp) @@ -619,7 +621,8 @@ def ln_prob(self, val, *, xp=np, **required_variables): with np.errstate(divide="ignore"): return xp.log(self.prob(val, **required_variables)) - def cdf(self, val, **required_variables): + @xp_wrap + def cdf(self, val, *, xp=np, **required_variables): r""" .. 
math:: \text{CDF}(\chi_\per) = N ln(1 + (\chi_\perp / \chi) ** 2) @@ -639,14 +642,15 @@ def cdf(self, val, **required_variables): """ self.update_conditions(**required_variables) chi_aligned = abs(required_variables[self._required_variables[0]]) - return np.maximum(np.minimum( + return xp.clip( (val >= self.minimum) * (val <= self.maximum) - * np.log(1 + (val / chi_aligned) ** 2) - / 2 / np.log(self._reference_maximum / chi_aligned) - , 1 - ), 0) + * xp.log(1 + (val / chi_aligned) ** 2) + / 2 / xp.log(self._reference_maximum / chi_aligned), + 0, + 1 + ) - def rescale(self, val, **required_variables): + def rescale(self, val, *, xp=np, **required_variables): r""" .. math:: \text{PPF}(\chi_\perp) = ((a_\max / \chi) ** (2x) - 1) ** 0.5 * \chi @@ -695,13 +699,13 @@ def __init__(self, minimum=-np.inf, maximum=np.inf): super().__init__(minimum=minimum, maximum=maximum, name=None, latex_label=None, unit=None) - def prob(self, val): + def prob(self, val, *, xp=np): """ Returns the result of the equation of state check in the conversion function. 
""" return val - def ln_prob(self, val): + def ln_prob(self, val, *, xp=np): if val: result = 0.0 @@ -1521,7 +1525,8 @@ def _check_imports(): raise ImportError("Must have healpy installed on this machine to use HealPixMapPrior") return healpy - def _rescale(self, samp, **kwargs): + @xp_wrap + def _rescale(self, samp, *, xp=np, **kwargs): """ Overwrites the _rescale method of BaseJoint Prior to rescale a single value from the unitcube onto two values (ra, dec) or 3 (ra, dec, dist) if distance is included @@ -1544,17 +1549,19 @@ def _rescale(self, samp, **kwargs): else: samp = samp[:, 0] pix_rescale = self.inverse_cdf(samp) - sample = np.empty((len(pix_rescale), 2)) - dist_samples = np.empty((len(pix_rescale))) + sample = xp.empty((len(pix_rescale), 2)) + dist_samples = xp.empty((len(pix_rescale))) for i, val in enumerate(pix_rescale): theta, ra = self.hp.pix2ang(self.nside, int(round(val))) dec = 0.5 * np.pi - theta - sample[i, :] = self.draw_from_pixel(ra, dec, int(round(val))) + sample = xpx.at(sample, i).set(xp.array(self.draw_from_pixel(ra, dec, int(round(val))))) if self.distance: self.update_distance(int(round(val))) - dist_samples[i] = self.distance_icdf(dist_samp[i]) + dist_samples = xpx.at(dist_samples, i).set( + xp.array(self.distance_icdf(dist_samp[i])) + ) if self.distance: - sample = np.vstack([sample[:, 0], sample[:, 1], dist_samples]) + sample = xp.vstack([sample[:, 0], sample[:, 1], dist_samples]) return sample.reshape((-1, self.num_vars)) def update_distance(self, pix_idx): @@ -1600,7 +1607,7 @@ def _check_norm(array): norm = np.finfo(array.dtype).eps return array / norm - def _sample(self, size, **kwargs): + def _sample(self, size, *, xp=np, **kwargs): """ Overwrites the _sample method of BaseJoint Prior. Picks a pixel value according to their probabilities, then uniformly samples ra, and decs that are contained in chosen pixel. 
         If the PriorDist includes distance it then
@@ -1631,7 +1638,7 @@
                 sample[samp, :] = [ra_dec[0], ra_dec[1], dist]
             else:
                 sample[samp, :] = self.draw_from_pixel(ra, dec, sample_pix[samp])
-        return sample.reshape((-1, self.num_vars))
+        return xp.array(sample.reshape((-1, self.num_vars)))
 
     def draw_distance(self, pix):
         """
@@ -1710,7 +1717,8 @@ def check_in_pixel(self, ra, dec, pix):
             pixel = self.hp.ang2pix(self.nside, theta, phi)
         return pix == pixel
 
-    def _ln_prob(self, samp, lnprob, outbounds):
+    @xp_wrap
+    def _ln_prob(self, samp, lnprob, outbounds, *, xp=None):
         """
         Overwrites the _lnprob method of BaseJoint Prior
 
@@ -1736,11 +1744,13 @@ def _ln_prob(self, samp, lnprob, outbounds):
                 phi, dec = samp[0]
             theta = 0.5 * np.pi - dec
             pixel = self.hp.ang2pix(self.nside, theta, phi)
-            lnprob[i] = np.log(self.prob[pixel] / self.pixel_area)
+            lnprob = xpx.at(lnprob, i).set(xp.log(self.prob[pixel] / self.pixel_area))
             if self.distance:
                 self.update_distance(pixel)
-                lnprob[i] += np.log(self.distance_pdf(dist) * dist ** 2)
-        lnprob[outbounds] = -np.inf
+                lnprob = xpx.at(lnprob, i).set(
+                    lnprob[i] + xp.log(self.distance_pdf(dist) * dist ** 2)
+                )
+        lnprob = xp.where(outbounds, -np.inf, lnprob)
         return lnprob
 
     def __eq__(self, other):
diff --git a/test/conftest.py b/test/conftest.py
index 6efaa82e2..355c3074b 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -37,7 +37,10 @@ def _xp(request):
             import numpy
             return numpy
         case "jax" | "jax.numpy":
+            import os
             import jax
+
+            os.environ["SCIPY_ARRAY_API"] = "1"
             jax.config.update("jax_enable_x64", True)
             return jax.numpy
         case _:
diff --git a/test/core/grid_test.py b/test/core/grid_test.py
index bc17b9ce7..c61a80731 100644
--- a/test/core/grid_test.py
+++ b/test/core/grid_test.py
@@ -209,7 +209,6 @@ def test_grid_from_array(self):
             grid_size=n_points,
             likelihood=self.likelihood,
             xp=self.xp,
-            vectorized=True,
         )
 
         self.assertTupleEqual((len(x0s), len(x1s)), grid.mesh_grid[0].shape)
diff --git
a/test/core/prior/analytical_test.py b/test/core/prior/analytical_test.py index 12892aca1..8325be232 100644 --- a/test/core/prior/analytical_test.py +++ b/test/core/prior/analytical_test.py @@ -1,16 +1,19 @@ import unittest -import numpy as np import bilby +import numpy as np +import pytest +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestDiscreteValuesPrior(unittest.TestCase): def test_single_sample(self): values = [1.1, 1.2, 1.3] discrete_value_prior = bilby.core.prior.DiscreteValues(values) in_prior = True for _ in range(1000): - s = discrete_value_prior.sample() + s = discrete_value_prior.sample(xp=self.xp) if s not in values: in_prior = False self.assertTrue(in_prior) @@ -20,7 +23,7 @@ def test_array_sample(self): nvalues = 4 discrete_value_prior = bilby.core.prior.DiscreteValues(values) N = 100000 - s = discrete_value_prior.sample(N) + s = discrete_value_prior.sample(N, xp=self.xp) zeros = np.sum(s == 1.0) ones = np.sum(s == 1.1) twos = np.sum(s == 1.2) @@ -35,21 +38,25 @@ def test_single_probability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.prob(1.1), 1 / N) - self.assertEqual(discrete_value_prior.prob(2.2), 1 / N) - self.assertEqual(discrete_value_prior.prob(300.0), 1 / N) - self.assertEqual(discrete_value_prior.prob(0.5), 0) - self.assertEqual(discrete_value_prior.prob(200), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.array(1.1)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.array(2.2)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.array(300.0)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.array(0.5)), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.array(200)), 0) def test_single_probability_unsorted(self): N = 3 values = [1.1, 300, 2.2] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.prob(1.1), 1 / N) - 
self.assertEqual(discrete_value_prior.prob(2.2), 1 / N) - self.assertEqual(discrete_value_prior.prob(300.0), 1 / N) - self.assertEqual(discrete_value_prior.prob(0.5), 0) - self.assertEqual(discrete_value_prior.prob(200), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.array(1.1)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.array(2.2)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.array(300.0)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.array(0.5)), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.array(200)), 0) + self.assertEqual( + discrete_value_prior.prob(self.xp.array(0.5)).__array_namespace__(), + self.xp, + ) def test_array_probability(self): N = 3 @@ -57,7 +64,7 @@ def test_array_probability(self): discrete_value_prior = bilby.core.prior.DiscreteValues(values) self.assertTrue( np.all( - discrete_value_prior.prob([1.1, 2.2, 2.2, 300.0, 200.0]) + discrete_value_prior.prob(self.xp.array([1.1, 2.2, 2.2, 300.0, 200.0])) == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) ) ) @@ -66,10 +73,14 @@ def test_single_lnprobability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.ln_prob(1.1), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(2.2), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(300), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(150), -np.inf) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.array(1.1)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.array(2.2)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.array(300)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.array(150)), -np.inf) + self.assertEqual( + discrete_value_prior.ln_prob(self.xp.array(0.5)).__array_namespace__(), + self.xp, + ) def test_array_lnprobability(self): N = 3 @@ -77,18 +88,20 @@ def test_array_lnprobability(self): 
discrete_value_prior = bilby.core.prior.DiscreteValues(values) self.assertTrue( np.all( - discrete_value_prior.ln_prob([1.1, 2.2, 2.2, 300, 150]) + discrete_value_prior.ln_prob(self.xp.array([1.1, 2.2, 2.2, 300, 150])) == np.array([-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]) ) ) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestCategoricalPrior(unittest.TestCase): def test_single_sample(self): categorical_prior = bilby.core.prior.Categorical(3) in_prior = True for _ in range(1000): - s = categorical_prior.sample() + s = categorical_prior.sample(xp=self.xp) if s not in [0, 1, 2]: in_prior = False self.assertTrue(in_prior) @@ -97,7 +110,7 @@ def test_array_sample(self): ncat = 4 categorical_prior = bilby.core.prior.Categorical(ncat) N = 100000 - s = categorical_prior.sample(N) + s = categorical_prior.sample(N, xp=self.xp) zeros = np.sum(s == 0) ones = np.sum(s == 1) twos = np.sum(s == 2) @@ -107,41 +120,57 @@ def test_array_sample(self): self.assertAlmostEqual(ones / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) self.assertAlmostEqual(twos / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) self.assertAlmostEqual(threes / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) + self.assertEqual(s.__array_namespace__(), self.xp) def test_single_probability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertEqual(categorical_prior.prob(0), 1 / N) - self.assertEqual(categorical_prior.prob(1), 1 / N) - self.assertEqual(categorical_prior.prob(2), 1 / N) - self.assertEqual(categorical_prior.prob(0.5), 0) + self.assertEqual(categorical_prior.prob(self.xp.array(0)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.array(1)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.array(2)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.array(0.5)), 0) + self.assertEqual( + categorical_prior.prob(self.xp.array(0.5)).__array_namespace__(), + self.xp, + ) def test_array_probability(self): N = 3 
categorical_prior = bilby.core.prior.Categorical(N) - self.assertTrue(np.all(categorical_prior.prob([0, 1, 1, 2, 3]) == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]))) + self.assertTrue(np.all( + categorical_prior.prob(self.xp.array([0, 1, 1, 2, 3])) + == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) + )) def test_single_lnprobability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertEqual(categorical_prior.ln_prob(0), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(1), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(2), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(0.5), -np.inf) + self.assertEqual(categorical_prior.ln_prob(self.xp.array(0)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.array(1)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.array(2)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.array(0.5)), -np.inf) + self.assertEqual( + categorical_prior.ln_prob(self.xp.array(0.5)).__array_namespace__(), + self.xp, + ) def test_array_lnprobability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertTrue(np.all(categorical_prior.ln_prob([0, 1, 1, 2, 3]) == np.array( - [-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]))) + self.assertTrue(np.all( + categorical_prior.ln_prob(self.xp.array([0, 1, 1, 2, 3])) + == np.array([-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]) + )) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestWeightedCategoricalPrior(unittest.TestCase): def test_single_sample(self): categorical_prior = bilby.core.prior.WeightedCategorical(3, [1, 2, 3]) in_prior = True for _ in range(1000): - s = categorical_prior.sample() + s = categorical_prior.sample(xp=self.xp) if s not in [0, 1, 2]: in_prior = False self.assertTrue(in_prior) @@ -157,39 +186,49 @@ def test_array_sample(self): weights = np.arange(1, ncat + 1) categorical_prior = bilby.core.prior.WeightedCategorical(ncat, 
weights=weights) N = 100000 - s = categorical_prior.sample(N) + s = categorical_prior.sample(N, xp=self.xp) cases = 0 - for i in categorical_prior.values: + for i in self.xp.array(categorical_prior.values): case = np.sum(s == i) cases += case self.assertAlmostEqual(case / N, categorical_prior.prob(i), places=int(np.log10(np.sqrt(N)))) self.assertAlmostEqual(case / N, weights[i] / np.sum(weights), places=int(np.log10(np.sqrt(N)))) self.assertEqual(cases, N) + self.assertEqual(s.__array_namespace__(), self.xp) def test_single_probability(self): N = 3 weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) - for i in categorical_prior.values: + for i in self.xp.array(categorical_prior.values): self.assertEqual(categorical_prior.prob(i), weights[i] / np.sum(weights)) - self.assertEqual(categorical_prior.prob(0.5), 0) + prob = categorical_prior.prob(self.xp.array(0.5)) + self.assertEqual(prob, 0) + self.assertEqual(prob.__array_namespace__(), self.xp) def test_array_probability(self): N = 3 - test_cases = [0, 1, 1, 2, 3] + test_cases = self.xp.array([0, 1, 1, 2, 3]) weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) probs = np.arange(1, N + 2) / np.sum(weights) probs[-1] = 0 - self.assertTrue(np.all(categorical_prior.prob(test_cases) == probs[test_cases])) + new = categorical_prior.prob(test_cases) + self.assertTrue(np.all(new == probs[test_cases])) + self.assertEqual(new.__array_namespace__(), self.xp) def test_single_lnprobability(self): N = 3 weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) - for i in categorical_prior.values: - self.assertEqual(categorical_prior.ln_prob(i), np.log(weights[i] / np.sum(weights))) - self.assertEqual(categorical_prior.prob(0.5), 0) + for i in self.xp.array(categorical_prior.values): + self.assertEqual( + categorical_prior.ln_prob(self.xp.array(i)), + np.log(weights[i] / 
np.sum(weights)), + ) + prob = categorical_prior.prob(self.xp.array(0.5)) + self.assertEqual(prob, 0) + self.assertEqual(prob.__array_namespace__(), self.xp) def test_array_lnprobability(self): N = 3 @@ -200,7 +239,9 @@ def test_array_lnprobability(self): ln_probs = np.log(np.arange(1, N + 2) / np.sum(weights)) ln_probs[-1] = -np.inf - self.assertTrue(np.all(categorical_prior.ln_prob(test_cases) == ln_probs[test_cases])) + new = categorical_prior.ln_prob(self.xp.array(test_cases)) + self.assertTrue(np.all(new == ln_probs[test_cases])) + self.assertEqual(new.__array_namespace__(), self.xp) def test_cdf(self): """ @@ -213,11 +254,12 @@ def test_cdf(self): categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) sample = categorical_prior.sample(size=10) - original = np.asarray(sample) - new = np.array(categorical_prior.rescale( + original = self.xp.asarray(sample) + new = self.xp.array(categorical_prior.rescale( categorical_prior.cdf(sample) )) np.testing.assert_array_equal(original, new) + self.assertEqual(type(new), type(original)) if __name__ == "__main__": diff --git a/test/core/prior/base_test.py b/test/core/prior/base_test.py index c9b788732..d83c3edd3 100644 --- a/test/core/prior/base_test.py +++ b/test/core/prior/base_test.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import numpy as np +import pytest import bilby @@ -56,7 +57,7 @@ def test_base_prob(self): self.assertTrue(np.isnan(self.prior.prob(5))) def test_base_ln_prob(self): - self.prior.prob = lambda val: val + self.prior.prob = lambda val, *, xp=None: val self.assertEqual(np.log(5), self.prior.ln_prob(5)) def test_is_in_prior(self): @@ -139,6 +140,8 @@ def test_prob_inside(self): self.assertEqual(1, self.prior.prob(0.5)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestConstraintPriorNormalisation(unittest.TestCase): def setUp(self): self.priors = dict( @@ -154,7 +157,7 @@ def conversion_func(parameters): def test_prob_integrate_to_one(self): keys = 
["a", "b", "c"] n_samples = 1000000 - samples = self.priors.sample_subset(keys=keys, size=n_samples) + samples = self.priors.sample_subset(keys=keys, size=n_samples, xp=self.xp) prob = self.priors.prob(samples, axis=0) dm1 = self.priors["a"].maximum - self.priors["a"].minimum dm2 = self.priors["b"].maximum - self.priors["b"].minimum @@ -169,5 +172,24 @@ def test_prob_integrate_to_one(self): self.assertAlmostEqual(1, integral, delta=7 * sigma_integral) +class TestPriorSubclassWithoutXpWarning(unittest.TestCase): + def test_custom_subclass_without_xp_issues_warning(self): + """Test that a custom prior subclass without xp parameter in rescale method issues a warning.""" + with pytest.warns( + DeprecationWarning, + match=r"rescale.*CustomPriorWithoutXp.*xp.*keyword argument", + ): + # Define a custom prior subclass that doesn't include xp in rescale method + class CustomPriorWithoutXp(bilby.core.prior.Prior): + def rescale(self, val): + """Custom rescale without xp parameter""" + return val * 2 + + prior = CustomPriorWithoutXp(name="custom_prior") + import jax.numpy as jnp + rescaled = prior.rescale(jnp.array([0.1, 0.2, 3])) + self.assertEqual(rescaled.__array_namespace__(), jnp) + + if __name__ == "__main__": unittest.main() diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index e7e5ec670..cafa0f73e 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import pickle +import pytest import bilby @@ -172,6 +173,8 @@ def test_cond_prior_instantiation_no_boundary_prior(self): self.assertIsNone(prior.boundary) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestConditionalPriorDict(unittest.TestCase): def setUp(self): def condition_func_1(reference_parameters, var_0): @@ -208,7 +211,12 @@ def condition_func_3(reference_parameters, var_1, var_2): self.conditional_priors_manually_set_items = ( 
bilby.core.prior.ConditionalPriorDict() ) - self.test_sample = dict(var_0=0.7, var_1=0.6, var_2=0.5, var_3=0.4) + self.test_sample = dict( + var_0=self.xp.array(0.7), + var_1=self.xp.array(0.6), + var_2=self.xp.array(0.5), + var_3=self.xp.array(0.4), + ) self.test_value = 1 / np.prod([self.test_sample[f"var_{ii}"] for ii in range(3)]) for key, value in dict( var_0=self.prior_0, @@ -260,12 +268,14 @@ def test_conditional_keys_setting_items(self): ) def test_prob(self): - self.assertEqual(self.test_value, self.conditional_priors.prob(sample=self.test_sample)) + prob = self.conditional_priors.prob(sample=self.test_sample) + self.assertEqual(self.test_value, prob) + self.assertEqual(prob.__array_namespace__(), self.xp) def test_prob_illegal_conditions(self): del self.conditional_priors["var_0"] with self.assertRaises(bilby.core.prior.IllegalConditionsException): - self.conditional_priors.prob(sample=self.test_sample) + self.conditional_priors.prob(sample=self.test_sample, xp=self.xp) def test_ln_prob(self): self.assertEqual(np.log(self.test_value), self.conditional_priors.ln_prob(sample=self.test_sample)) @@ -356,7 +366,7 @@ def test_rescale_with_joint_prior(self): res = priordict.rescale(keys=keys, theta=ref_variables) self.assertEqual(np.shape(res), (6,)) - self.assertListEqual([isinstance(r, float) for r in res], 6 * [True]) + self.assertEqual(res.__array_namespace__(), self.xp) # check conditional values are still as expected expected = [self.test_sample["var_0"]] @@ -447,6 +457,8 @@ def _tp_conditional_uniform(ref_params, period): prior.sample_subset(["tp"], 1000) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestDirichletPrior(unittest.TestCase): def setUp(self): @@ -456,6 +468,10 @@ def tearDown(self): if os.path.isdir("priors"): shutil.rmtree("priors") + def test_samples_correct_type(self): + samples = self.priors.sample(10, xp=self.xp) + self.assertEqual(samples["dirichlet_1"].__array_namespace__(), self.xp) + def 
test_samples_sum_to_less_than_one(self): """ Test that the samples sum to less than one as required for the diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index 3970277e0..59d5d83bf 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import numpy as np +import pytest import bilby @@ -22,6 +23,8 @@ def __init__(self, names, bounds=None): setattr(bilby.core.prior, "FakeJointPriorDist", FakeJointPriorDist) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestPriorDict(unittest.TestCase): def setUp(self): @@ -268,30 +271,40 @@ def test_dict_argument_is_not_string_or_dict(self): def test_sample_subset_correct_size(self): size = 7 samples = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys(), size=size + keys=self.prior_set_from_dict.keys(), size=size, + xp=self.xp, ) self.assertEqual(len(self.prior_set_from_dict), len(samples)) for key in samples: self.assertEqual(size, len(samples[key])) + self.assertEqual(samples[key].__array_namespace__(), self.xp) def test_sample_subset_correct_size_when_non_priors_in_dict(self): self.prior_set_from_dict["asdf"] = "not_a_prior" samples = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys() + keys=self.prior_set_from_dict.keys(), + xp=self.xp, ) self.assertEqual(len(self.prior_set_from_dict) - 1, len(samples)) + for key in samples: + self.assertIsNotNone(samples[key].__array_namespace__(), self.xp) def test_sample_subset_with_actual_subset(self): size = 3 - samples = self.prior_set_from_dict.sample_subset(keys=["length"], size=size) - expected = dict(length=np.array([42.0, 42.0, 42.0])) + samples = self.prior_set_from_dict.sample_subset( + keys=["length"], size=size, xp=self.xp + ) + expected = dict(length=self.xp.array([42.0, 42.0, 42.0])) self.assertTrue(np.array_equal(expected["length"], samples["length"])) + 
self.assertEqual(samples["length"].__array_namespace__(), self.xp) def test_sample_subset_constrained_as_array(self): size = 3 keys = ["mass", "speed"] - out = self.prior_set_from_dict.sample_subset_constrained_as_array(keys, size) - self.assertTrue(isinstance(out, np.ndarray)) + out = self.prior_set_from_dict.sample_subset_constrained_as_array( + keys, size, xp=self.xp + ) + self.assertEqual(out.__array_namespace__(), self.xp) self.assertTrue(out.shape == (len(keys), size)) def test_sample_subset_constrained(self): @@ -312,7 +325,7 @@ def conversion_function(parameters): with patch("bilby.core.prior.logger.warning") as mock_warning: samples1 = priors1.sample_subset_constrained( - keys=list(priors1.keys()), size=N + keys=list(priors1.keys()), size=N, xp=self.xp ) self.assertEqual(len(priors1) - 1, len(samples1)) for key in samples1: @@ -325,7 +338,7 @@ def conversion_function(parameters): with patch("bilby.core.prior.logger.warning") as mock_warning: samples2 = priors2.sample_subset_constrained( - keys=list(priors2.keys()), size=N + keys=list(priors2.keys()), size=N, xp=self.xp ) self.assertEqual(len(priors2), len(samples2)) for key in samples2: @@ -336,27 +349,31 @@ def test_sample(self): size = 7 bilby.core.utils.random.seed(42) samples1 = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys(), size=size + keys=self.prior_set_from_dict.keys(), size=size, xp=self.xp ) bilby.core.utils.random.seed(42) - samples2 = self.prior_set_from_dict.sample(size=size) + samples2 = self.prior_set_from_dict.sample(size=size, xp=self.xp) self.assertEqual(set(samples1.keys()), set(samples2.keys())) for key in samples1: self.assertTrue(np.array_equal(samples1[key], samples2[key])) - + self.assertEqual(samples1[key].__array_namespace__(), self.xp) + self.assertEqual(samples2[key].__array_namespace__(), self.xp) + def test_prob(self): - samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"]) + samples = 
self.prior_set_from_dict.sample_subset(keys=["mass", "speed"], xp=self.xp) expected = self.first_prior.prob(samples["mass"]) * self.second_prior.prob( samples["speed"] ) self.assertEqual(expected, self.prior_set_from_dict.prob(samples)) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_ln_prob(self): - samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"]) + samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"], xp=self.xp) expected = self.first_prior.ln_prob( samples["mass"] ) + self.second_prior.ln_prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_rescale(self): theta = [0.5, 0.5, 0.5] @@ -380,13 +397,14 @@ def test_cdf(self): Note that the format of inputs/outputs is different between the two methods. """ - sample = self.prior_set_from_dict.sample() - original = np.array(list(sample.values())) - new = np.array(self.prior_set_from_dict.rescale( + sample = self.prior_set_from_dict.sample(xp=self.xp) + original = self.xp.array(list(sample.values())) + new = self.xp.array(self.prior_set_from_dict.rescale( sample.keys(), self.prior_set_from_dict.cdf(sample=sample).values() )) self.assertLess(max(abs(original - new)), 1e-10) + self.assertEqual(new.__array_namespace__(), self.xp) def test_redundancy(self): for key in self.prior_set_from_dict.keys(): diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index bed42cf19..c3fa1e865 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -2,6 +2,7 @@ import unittest import numpy as np import os +import pytest import scipy.stats as ss from scipy.integrate import trapezoid @@ -26,6 +27,8 @@ ) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestPriorClasses(unittest.TestCase): def setUp(self): # set multivariate Gaussian @@ -265,24 +268,27 @@ def test_minimum_rescaling(self): # and so the 
rescale function doesn't quite return the lower bound continue elif bilby.core.prior.JointPrior in prior.__class__.__mro__: - minimum_sample = prior.rescale(0) + minimum_sample = prior.rescale(self.xp.array(0)) if prior.dist.filled_rescale(): self.assertAlmostEqual(minimum_sample[0], prior.minimum) self.assertAlmostEqual(minimum_sample[1], prior.minimum) else: - minimum_sample = prior.rescale(0) + minimum_sample = prior.rescale(self.xp.array(0)) self.assertAlmostEqual(minimum_sample, prior.minimum) def test_maximum_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: if bilby.core.prior.JointPrior in prior.__class__.__mro__: - maximum_sample = prior.rescale(0) + maximum_sample = prior.rescale(self.xp.array(0)) if prior.dist.filled_rescale(): self.assertAlmostEqual(maximum_sample[0], prior.maximum) self.assertAlmostEqual(maximum_sample[1], prior.maximum) + elif isinstance(prior, bilby.gw.prior.AlignedSpin): + maximum_sample = prior.rescale(self.xp.array(1)) + self.assertGreater(maximum_sample, 0.997) else: - maximum_sample = prior.rescale(1) + maximum_sample = prior.rescale(self.xp.array(1)) self.assertAlmostEqual(maximum_sample, prior.maximum) def test_many_sample_rescaling(self): @@ -291,20 +297,22 @@ def test_many_sample_rescaling(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - many_samples = prior.rescale(np.random.uniform(0, 1, 1000)) + many_samples = prior.rescale(self.xp.array(np.random.uniform(0, 1, 1000))) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_rescale(): continue self.assertTrue( all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)) ) + self.assertEqual(many_samples.__array_namespace__(), self.xp) def test_least_recently_sampled(self): for prior in self.priors: - least_recently_sampled_expected = prior.sample() + least_recently_sampled_expected = 
prior.sample(xp=self.xp) self.assertEqual( least_recently_sampled_expected, prior.least_recently_sampled ) + self.assertEqual(least_recently_sampled_expected.__array_namespace__(), self.xp) def test_sampling_single(self): """Test that sampling from the prior always returns values within its domain.""" @@ -312,10 +320,11 @@ def test_sampling_single(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - single_sample = prior.sample() + single_sample = prior.sample(xp=self.xp) self.assertTrue( (single_sample >= prior.minimum) & (single_sample <= prior.maximum) ) + self.assertEqual(single_sample.__array_namespace__(), self.xp) def test_sampling_many(self): """Test that sampling from the prior always returns values within its domain.""" @@ -323,17 +332,18 @@ def test_sampling_many(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - many_samples = prior.sample(5000) + many_samples = prior.sample(5000, xp=self.xp) self.assertTrue( (all(many_samples >= prior.minimum)) & (all(many_samples <= prior.maximum)) ) + self.assertEqual(many_samples.__array_namespace__(), self.xp) def test_probability_above_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in self.priors: if prior.maximum != np.inf: - outside_domain = np.linspace( + outside_domain = self.xp.linspace( prior.maximum + 1, prior.maximum + 1e4, 1000 ) if bilby.core.prior.JointPrior in prior.__class__.__mro__: @@ -349,7 +359,7 @@ def test_probability_below_domain(self): # SymmetricLogUniform has support down to -maximum continue if prior.minimum != -np.inf: - outside_domain = np.linspace( + outside_domain = self.xp.linspace( prior.minimum - 1e4, prior.minimum - 1, 1000 ) if bilby.core.prior.JointPrior in prior.__class__.__mro__: @@ -360,31 +370,39 @@ def 
test_probability_below_domain(self): def test_least_recently_sampled_2(self): for prior in self.priors: - lrs = prior.sample() + lrs = prior.sample(xp=self.xp) self.assertEqual(lrs, prior.least_recently_sampled) + self.assertEqual(lrs.__array_namespace__(), self.xp) def test_prob_and_ln_prob(self): for prior in self.priors: - sample = prior.sample() + sample = prior.sample(xp=self.xp) if not bilby.core.prior.JointPrior in prior.__class__.__mro__: # noqa # due to the way that the Multivariate Gaussian prior must sequentially call # the prob and ln_prob functions, it must be ignored in this test. + lnprob = prior.ln_prob(sample) + prob = prior.prob(sample) + # lower precision for jax running tests with float32 self.assertAlmostEqual( - np.log(prior.prob(sample)), prior.ln_prob(sample), 12 + self.xp.log(prob), lnprob, 6 ) + self.assertEqual(lnprob.__array_namespace__(), self.xp) + self.assertEqual(prob.__array_namespace__(), self.xp) def test_many_prob_and_many_ln_prob(self): for prior in self.priors: - samples = prior.sample(10) + samples = prior.sample(10, xp=self.xp) if not bilby.core.prior.JointPrior in prior.__class__.__mro__: # noqa ln_probs = prior.ln_prob(samples) probs = prior.prob(samples) for sample, logp, p in zip(samples, ln_probs, probs): self.assertAlmostEqual(prior.ln_prob(sample), logp) self.assertAlmostEqual(prior.prob(sample), p) + self.assertEqual(ln_probs.__array_namespace__(), self.xp) + self.assertEqual(probs.__array_namespace__(), self.xp) def test_cdf_is_inverse_of_rescaling(self): - domain = np.linspace(0, 1, 100) + domain = self.xp.linspace(0, 1, 100) threshold = 1e-9 for prior in self.priors: if ( @@ -392,6 +410,9 @@ def test_cdf_is_inverse_of_rescaling(self): or bilby.core.prior.JointPrior in prior.__class__.__mro__ ): continue + elif isinstance(prior, bilby.core.prior.StudentT) and "jax" in str(self.xp): + # JAX implementation of StudentT prior rescale is not accurate enough + continue elif isinstance(prior, 
bilby.core.prior.WeightedDiscreteValues): rescaled = prior.rescale(domain) cdf_vals = prior.cdf(rescaled) @@ -399,15 +420,18 @@ def test_cdf_is_inverse_of_rescaling(self): cdf_vals_2 = prior.cdf(rescaled_2) self.assertTrue(np.array_equal(rescaled, rescaled_2)) max_difference = max(np.abs(cdf_vals - cdf_vals_2)) + for arr in [rescaled, rescaled_2, cdf_vals, cdf_vals_2]: + self.assertEqual(arr.__array_namespace__(), self.xp) else: rescaled = prior.rescale(domain) max_difference = max(np.abs(domain - prior.cdf(rescaled))) + self.assertEqual(rescaled.__array_namespace__(), self.xp) self.assertLess(max_difference, threshold) def test_cdf_one_above_domain(self): for prior in self.priors: if prior.maximum != np.inf: - outside_domain = np.linspace( + outside_domain = self.xp.linspace( prior.maximum + 1, prior.maximum + 1e4, 1000 ) self.assertTrue(all(prior.cdf(outside_domain) == 1)) @@ -423,7 +447,7 @@ def test_cdf_zero_below_domain(self): ): continue if prior.minimum != -np.inf: - outside_domain = np.linspace( + outside_domain = self.xp.linspace( prior.minimum - 1e4, prior.minimum - 1, 1000 ) self.assertTrue(all(np.nan_to_num(prior.cdf(outside_domain)) == 0)) @@ -578,7 +602,7 @@ def test_probability_in_domain(self): maximum = 1e5 else: maximum = prior.maximum - domain = np.linspace(minimum, maximum, 1000) + domain = self.xp.linspace(minimum, maximum, 1000) self.assertTrue(all(prior.prob(domain) >= 0)) def test_probability_surrounding_domain(self): @@ -591,7 +615,7 @@ def test_probability_surrounding_domain(self): # SymmetricLogUniform has support down to -maximum continue with np.errstate(invalid="ignore"): - surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) + surround_domain = self.xp.linspace(prior.minimum - 1, prior.maximum + 1, 1000) indomain = (surround_domain >= prior.minimum) | ( surround_domain <= prior.maximum ) @@ -645,11 +669,13 @@ def test_normalized(self): domain = np.linspace(prior.minimum, prior.maximum, 10000) elif isinstance(prior, 
bilby.core.prior.WeightedDiscreteValues): domain = prior.values - self.assertTrue(np.sum(prior.prob(domain)) == 1) + self.assertTrue(np.sum(prior.prob(self.xp.array(domain))) == 1) continue else: domain = np.linspace(prior.minimum, prior.maximum, 1000) - self.assertAlmostEqual(trapezoid(prior.prob(domain), domain), 1, 3) + probs = prior.prob(self.xp.array(domain)) + self.assertAlmostEqual(trapezoid(np.array(probs), domain), 1, 3) + self.assertEqual(probs.__array_namespace__(), self.xp) def test_accuracy(self): """Test that each of the priors' functions is calculated accurately, as compared to scipy's calculations""" @@ -744,11 +770,15 @@ def test_accuracy(self): bilby.core.prior.WeightedDiscreteValues, ) if isinstance(prior, (testTuple)): - np.testing.assert_almost_equal(prior.prob(domain), scipy_prob) - np.testing.assert_almost_equal(prior.ln_prob(domain), scipy_lnprob) - np.testing.assert_almost_equal(prior.cdf(domain), scipy_cdf) + print(prior) + np.testing.assert_almost_equal(prior.prob(self.xp.array(domain)), scipy_prob) + np.testing.assert_almost_equal(prior.ln_prob(self.xp.array(domain)), scipy_lnprob) + np.testing.assert_almost_equal(prior.cdf(self.xp.array(domain)), scipy_cdf) + if isinstance(prior, bilby.core.prior.StudentT) and "jax" in str(self.xp): + # JAX implementation of StudentT prior rescale is not accurate enough + continue np.testing.assert_almost_equal( - prior.rescale(rescale_domain), scipy_rescale + prior.rescale(self.xp.array(rescale_domain)), scipy_rescale ) def test_unit_setting(self): @@ -833,7 +863,7 @@ def test_set_maximum_setting(self): ): continue prior.maximum = (prior.maximum + prior.minimum) / 2 - self.assertTrue(max(prior.sample(10000)) < prior.maximum) + self.assertTrue(max(prior.sample(10000, xp=self.xp)) < prior.maximum) def test_set_minimum_setting(self): for prior in self.priors: @@ -859,25 +889,25 @@ def test_set_minimum_setting(self): ): continue prior.minimum = (prior.maximum + prior.minimum) / 2 - 
self.assertTrue(min(prior.sample(10000)) > prior.minimum) - - def test_jax_methods(self): - import jax - - points = jax.numpy.linspace(1e-3, 1 - 1e-3, 10) - for prior in self.priors: - if bilby.core.prior.JointPrior in prior.__class__.__mro__: - continue - scaled = prior.rescale(points) - assert isinstance(scaled, jax.Array) - if isinstance(prior, bilby.core.prior.DeltaFunction): - continue - probs = prior.prob(scaled) - assert min(probs) > 0 - assert max(abs(jax.numpy.log(probs) - prior.ln_prob(scaled))) < 1e-6 - if isinstance(prior, bilby.core.prior.WeightedDiscreteValues): - continue - assert max(abs(prior.cdf(scaled) - points)) < 1e-6 + self.assertTrue(min(prior.sample(10000, xp=self.xp)) > prior.minimum) + + # def test_jax_methods(self): + # import jax + + # points = jax.numpy.linspace(1e-3, 1 - 1e-3, 10) + # for prior in self.priors: + # if bilby.core.prior.JointPrior in prior.__class__.__mro__: + # continue + # scaled = prior.rescale(points) + # assert isinstance(scaled, jax.Array) + # if isinstance(prior, bilby.core.prior.DeltaFunction): + # continue + # probs = prior.prob(scaled) + # assert min(probs) > 0 + # assert max(abs(jax.numpy.log(probs) - prior.ln_prob(scaled))) < 1e-6 + # if isinstance(prior, bilby.core.prior.WeightedDiscreteValues): + # continue + # assert max(abs(prior.cdf(scaled) - points)) < 1e-6 if __name__ == "__main__": diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index 8cb2fcf1d..501f3e39b 100644 --- a/test/core/prior/slabspike_test.py +++ b/test/core/prior/slabspike_test.py @@ -1,6 +1,9 @@ -import numpy as np import unittest +import array_api_compat as aac +import numpy as np +import pytest + import bilby from bilby.core.prior.slabspike import SlabSpikePrior from bilby.core.prior.analytical import Uniform, PowerLaw, LogUniform, TruncatedGaussian, \ @@ -60,12 +63,14 @@ def test_set_spike_height_domain_edge(self): self.prior.spike_height = 1 +@pytest.mark.array_backend 
+@pytest.mark.usefixtures("xp_class") class TestSlabSpikeClasses(unittest.TestCase): def setUp(self): self.minimum = 0.4 self.maximum = 2.4 - self.spike_loc = 1.5 + self.spike_loc = self.xp.array(1.5) self.spike_height = 0.3 self.slabs = [ @@ -80,15 +85,17 @@ def setUp(self): HalfGaussian(sigma=1), LogNormal(mu=1, sigma=2), Exponential(mu=2), - StudentT(df=2), Logistic(mu=2, scale=1), Cauchy(alpha=1, beta=2), Gamma(k=1, theta=1.), - ChiSquared(nu=2)] + ChiSquared(nu=2), + ] + if not aac.is_jax_namespace(self.xp): + self.slabs.append(StudentT(df=2)) self.slab_spikes = [SlabSpikePrior(slab, spike_height=self.spike_height, spike_location=self.spike_loc) for slab in self.slabs] - self.test_nodes_finite_support = np.linspace(self.minimum, self.maximum, 1000) - self.test_nodes_infinite_support = np.linspace(-10, 10, 1000) + self.test_nodes_finite_support = self.xp.linspace(self.minimum, self.maximum, 1000) + self.test_nodes_infinite_support = self.xp.linspace(-10, 10, 1000) self.test_nodes = [self.test_nodes_finite_support if np.isinf(slab.minimum) or np.isinf(slab.maximum) else self.test_nodes_finite_support for slab in self.slabs] @@ -102,24 +109,12 @@ def tearDown(self): del self.test_nodes_finite_support del self.test_nodes_infinite_support - def test_jax_methods(self): - import jax - - points = jax.numpy.linspace(1e-3, 1 - 1e-3, 10) - for prior in self.slab_spikes: - scaled = prior.rescale(points) - assert isinstance(scaled, jax.Array) - if isinstance(prior, bilby.core.prior.DeltaFunction): - continue - probs = prior.prob(scaled) - assert min(probs) > 0 - assert max(abs(jax.numpy.log(probs) - prior.ln_prob(scaled))) < 1e-6 - def test_prob_on_slab(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): expected = slab.prob(test_nodes) * slab_spike.slab_fraction actual = slab_spike.prob(test_nodes) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_prob_on_spike(self): for
slab_spike in self.slab_spikes: @@ -130,10 +125,13 @@ def test_ln_prob_on_slab(self): expected = slab.ln_prob(test_nodes) + np.log(slab_spike.slab_fraction) actual = slab_spike.ln_prob(test_nodes) self.assertTrue(np.array_equal(expected, actual)) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_ln_prob_on_spike(self): for slab_spike in self.slab_spikes: - self.assertEqual(np.inf, slab_spike.ln_prob(self.spike_loc)) + expected = slab_spike.ln_prob(self.spike_loc) + self.assertEqual(np.inf, expected) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_inverse_cdf_below_spike_with_spike_at_minimum(self): for slab in self.slabs: @@ -156,19 +154,22 @@ def test_cdf_below_spike(self): expected = slab.cdf(test_nodes) * slab_spike.slab_fraction actual = slab_spike.cdf(test_nodes) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_cdf_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): expected = slab.cdf(self.spike_loc) * slab_spike.slab_fraction actual = slab_spike.cdf(self.spike_loc) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_cdf_above_spike(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): test_nodes = test_nodes[np.where(test_nodes > self.spike_loc)] expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + self.spike_height actual = slab_spike.cdf(test_nodes) - self.assertTrue(np.array_equal(expected, actual)) + np.testing.assert_allclose(expected, actual, rtol=1e-12) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_cdf_at_minimum(self): for slab_spike in self.slab_spikes: @@ -185,31 +186,36 @@ def test_cdf_at_maximum(self): def test_rescale_no_spike(self): for slab in self.slabs: slab_spike = SlabSpikePrior(slab=slab, spike_height=0, spike_location=slab.minimum) - vals = np.linspace(0, 
1, 1000) + vals = self.xp.linspace(0, 1, 1000) expected = slab.rescale(vals) actual = slab_spike.rescale(vals) - print(slab) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_rescale_below_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - vals = np.linspace(0, slab_spike.inverse_cdf_below_spike, 1000) + vals = self.xp.linspace(0, slab_spike.inverse_cdf_below_spike, 1000) expected = slab.rescale(vals / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_rescale_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - vals = np.linspace(slab_spike.inverse_cdf_below_spike, - slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000) - expected = np.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction) + vals = self.xp.linspace( + slab_spike.inverse_cdf_below_spike, + slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000 + ) + expected = self.xp.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(expected.__array_namespace__(), self.xp) def test_rescale_above_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - vals = np.linspace(slab_spike.inverse_cdf_below_spike + self.spike_height, 1, 1000) - expected = np.ones(len(vals)) * slab.rescale( + vals = self.xp.linspace(slab_spike.inverse_cdf_below_spike + self.spike_height, 1, 1000) + expected = self.xp.ones(len(vals)) * slab.rescale( (vals - self.spike_height) / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + self.assertEqual(expected.__array_namespace__(), self.xp) diff --git a/test/core/result_test.py 
b/test/core/result_test.py index 23ba8e6b5..64fc9c14f 100644 --- a/test/core/result_test.py +++ b/test/core/result_test.py @@ -13,6 +13,8 @@ from bilby.core.utils import logger +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestJson(unittest.TestCase): def setUp(self): @@ -28,12 +30,12 @@ def test_list_encoding(self): self.assertTrue(np.all(data["x"] == decoded["x"])) def test_array_encoding(self): - data = dict(x=np.array([1, 2, 3.4])) + data = dict(x=self.xp.array([1, 2, 3.4])) encoded = json.dumps(data, cls=self.encoder) decoded = json.loads(encoded, object_hook=self.decoder) self.assertEqual(data.keys(), decoded.keys()) self.assertEqual(type(data["x"]), type(decoded["x"])) - self.assertTrue(np.all(data["x"] == decoded["x"])) + self.assertTrue(self.xp.all(data["x"] == decoded["x"])) def test_complex_encoding(self): data = dict(x=1 + 3j) @@ -919,6 +921,8 @@ def test_reweight_different_likelihood_weights_correct(self): self.assertNotEqual(new.log_evidence, self.result.log_evidence) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestResultSaveAndRead(unittest.TestCase): @pytest.fixture(autouse=True) @@ -944,7 +948,11 @@ def setUp(self): search_parameter_keys=["x", "y"], fixed_parameter_keys=["c", "d"], priors=priors, - sampler_kwargs=dict(test="test", func=lambda x: x), + sampler_kwargs=dict( + test="test", + func=lambda x: x, + some_array=self.xp.ones((5, 5)), + ), injection_parameters=dict(x=0.5, y=0.5), meta_data=dict(test="test"), sampling_time=100.0, diff --git a/test/gw/prior_test.py b/test/gw/prior_test.py index 2d35986cf..022bc0a56 100644 --- a/test/gw/prior_test.py +++ b/test/gw/prior_test.py @@ -10,6 +10,7 @@ from scipy.stats import ks_2samp import matplotlib.pyplot as plt import pandas as pd +import pytest import bilby from bilby.core.prior import Uniform, Constraint @@ -557,14 +558,16 @@ def test_luminosity_distance_to_comoving_distance(self): self.assertEqual(new_prior.name, "comoving_distance") 
+@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestAlignedSpin(unittest.TestCase): def setUp(self): pass def test_default_prior_matches_analytic(self): prior = bilby.gw.prior.AlignedSpin() - chis = np.linspace(-1, 1, 20) - analytic = -np.log(np.abs(chis)) / 2 + chis = self.xp.linspace(-1, 1, 20) + analytic = -self.xp.log(self.xp.abs(chis)) / 2 max_difference = max(abs(analytic - prior.prob(chis))) self.assertAlmostEqual(max_difference, 0, 2) @@ -572,12 +575,15 @@ def test_non_analytic_form_has_correct_statistics(self): a_prior = bilby.core.prior.TruncatedGaussian(mu=0, sigma=0.1, minimum=0, maximum=1) z_prior = bilby.core.prior.TruncatedGaussian(mu=0.4, sigma=0.2, minimum=-1, maximum=1) chi_prior = bilby.gw.prior.AlignedSpin(a_prior, z_prior) - chis = chi_prior.sample(100000) - alts = a_prior.sample(100000) * z_prior.sample(100000) + chis = chi_prior.sample(100000, xp=self.xp) + alts = a_prior.sample(100000, xp=self.xp) * z_prior.sample(100000, xp=self.xp) self.assertAlmostEqual(np.mean(chis), np.mean(alts), 2) self.assertAlmostEqual(np.std(chis), np.std(alts), 2) + self.assertEqual(chis.__array_namespace__(), self.xp) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestConditionalChiUniformSpinMagnitude(unittest.TestCase): def setUp(self): @@ -588,9 +594,10 @@ def test_marginalized_prior_is_uniform(self): priors["a_1"] = bilby.gw.prior.ConditionalChiUniformSpinMagnitude( minimum=0.1, maximum=priors["chi_1"].maximum, name="a_1" ) - samples = priors.sample(100000)["a_1"] + samples = priors.sample(100000, xp=self.xp)["a_1"] ks = ks_2samp(samples, np.random.uniform(0, priors["chi_1"].maximum, 100000)) self.assertTrue(ks.pvalue > 0.001) + self.assertEqual(samples.__array_namespace__(), self.xp) if __name__ == "__main__": From 4222906a620fe0a216ab54a20adb53cf329341b6 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sun, 1 Feb 2026 22:14:41 -0500 Subject: [PATCH 076/110] FMT: precommit fixes --- bilby/compat/patches.py | 1 
- bilby/core/grid.py | 2 +- bilby/core/prior/base.py | 1 - test/core/grid_test.py | 3 --- test/core/prior/base_test.py | 2 +- test/core/prior/dict_test.py | 2 +- 6 files changed, 3 insertions(+), 8 deletions(-) diff --git a/bilby/compat/patches.py b/bilby/compat/patches.py index 7c497a24e..db18c3974 100644 --- a/bilby/compat/patches.py +++ b/bilby/compat/patches.py @@ -1,5 +1,4 @@ import array_api_compat as aac -import numpy as np from .utils import BackendNotImplementedError diff --git a/bilby/core/grid.py b/bilby/core/grid.py index de2972b10..bde6b9d73 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -345,7 +345,7 @@ def _evaluate_recursion(self, dimension, parameters): parameters[name] == self.sample_points[name])[0].item()] for name in self.parameter_names]) self._ln_likelihood[current_point] = ( -_safe_likelihood_call(self.likelihood, parameters) + _safe_likelihood_call(self.likelihood, parameters) ) else: name = self.parameter_names[dimension] diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index ea6ef5475..671b46233 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -14,7 +14,6 @@ decode_bilby_json, logger, get_dict_with_properties, - WrappedInterp1d as interp1d, ) from ...compat.utils import xp_wrap diff --git a/test/core/grid_test.py b/test/core/grid_test.py index c61a80731..d99af8327 100644 --- a/test/core/grid_test.py +++ b/test/core/grid_test.py @@ -8,9 +8,6 @@ import bilby from bilby.compat.patches import multivariate_logpdf -import jax -from functools import partial - class MultiGaussian(bilby.Likelihood): # set 2D multivariate Gaussian likelihood diff --git a/test/core/prior/base_test.py b/test/core/prior/base_test.py index d83c3edd3..21e857fa2 100644 --- a/test/core/prior/base_test.py +++ b/test/core/prior/base_test.py @@ -184,7 +184,7 @@ class CustomPriorWithoutXp(bilby.core.prior.Prior): def rescale(self, val): """Custom rescale without xp parameter""" return val * 2 - + prior = 
CustomPriorWithoutXp(name="custom_prior") import jax.numpy as jnp rescaled = prior.rescale(jnp.array([0.1, 0.2, 3])) diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index 59d5d83bf..96479326d 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -358,7 +358,7 @@ def test_sample(self): self.assertTrue(np.array_equal(samples1[key], samples2[key])) self.assertEqual(samples1[key].__array_namespace__(), self.xp) self.assertEqual(samples2[key].__array_namespace__(), self.xp) - + def test_prob(self): samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"], xp=self.xp) expected = self.first_prior.prob(samples["mass"]) * self.second_prior.prob( From 4bb8805b6b70263637c3a7d3a93e21c9a4061cf9 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sun, 1 Feb 2026 22:24:27 -0500 Subject: [PATCH 077/110] TEST: fix jax tests --- bilby/compat/utils.py | 12 +++++++----- bilby/core/grid.py | 1 - bilby/core/prior/dict.py | 1 - bilby/core/prior/slabspike.py | 2 -- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py index 5e00f8969..34b819ab9 100644 --- a/bilby/compat/utils.py +++ b/bilby/compat/utils.py @@ -55,7 +55,7 @@ def xp_wrap(func, no_xp=False): no_xp: bool If True, the decorator will not attempt to add the 'xp' keyword argument and so the wrapper is a no-op. 
- + Returns ======= function @@ -65,13 +65,15 @@ def wrapped(self, *args, xp=None, **kwargs): if not no_xp and xp is None: try: + # if the user specified the target arrays in kwargs + # we need to be able to support this if len(args) > 0: - array_module = array_namespace(*args) + xp = array_namespace(*args) elif len(kwargs) > 0: - array_module = array_namespace(*kwargs.values()) + xp = array_namespace(*kwargs.values()) else: - array_module = np - kwargs["xp"] = array_module + xp = np + kwargs["xp"] = xp except TypeError: kwargs["xp"] = np elif not no_xp: diff --git a/bilby/core/grid.py b/bilby/core/grid.py index bde6b9d73..55ff2fb2d 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -331,7 +331,6 @@ def _evaluate(self): self._ln_likelihood = vmap(self.likelihood.log_likelihood)( {key: self.mesh_grid[i].flatten() for i, key in enumerate(self.parameter_names)} ).reshape(self.mesh_grid[0].shape) - print(type(self._ln_likelihood)) else: self._ln_likelihood = xp.empty(self.mesh_grid[0].shape) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index cad562d16..0eaa92008 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -822,7 +822,6 @@ def prob(self, sample, *, xp=None, **kwargs): """ self._prepare_evaluation(*zip(*sample.items())) - print(sample, xp) res = xp.asarray([ self[key].prob(sample[key], **self.get_required_variables(key), xp=xp) for key in sample diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index 30ca16639..23aed86d3 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -96,13 +96,11 @@ def rescale(self, val, *, xp=None): slab_scaled = self._contracted_rescale( val - self.spike_height * higher_indices, xp=xp ) - print(type(slab_scaled)) res = xp.select( [lower_indices | higher_indices, intermediate_indices], [slab_scaled, self.spike_location], ) - print(type(res)) return res @xp_wrap From f34646a73cb54ceb14412b4d7bab9e7c227b7b78 Mon
Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 05:48:14 -0500 Subject: [PATCH 078/110] TEST: add basic gw conversion jax tests --- test/conftest.py | 1 + test/core/prior/slabspike_test.py | 3 ++ test/gw/conversion_test.py | 46 +++++++++++++++++++++++-------- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 355c3074b..d0d1ad79b 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -15,6 +15,7 @@ def pytest_addoption(parser): def pytest_configure(config): config.addinivalue_line("markers", "requires_roqs: mark a test that requires ROQs") + config.addinivalue_line("markers", "array_backend: mark that a test uses all array backends") def pytest_collection_modifyitems(config, items): diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index 501f3e39b..53eb5ebee 100644 --- a/test/core/prior/slabspike_test.py +++ b/test/core/prior/slabspike_test.py @@ -150,6 +150,7 @@ def test_inverse_cdf_below_spike_arbitrary_position(self): def test_cdf_below_spike(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): + print(slab) test_nodes = test_nodes[np.where(test_nodes < self.spike_loc)] expected = slab.cdf(test_nodes) * slab_spike.slab_fraction actual = slab_spike.cdf(test_nodes) @@ -158,6 +159,7 @@ def test_cdf_below_spike(self): def test_cdf_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): + print(slab) expected = slab.cdf(self.spike_loc) * slab_spike.slab_fraction actual = slab_spike.cdf(self.spike_loc) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) @@ -165,6 +167,7 @@ def test_cdf_at_spike(self): def test_cdf_above_spike(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): + print(slab) test_nodes = test_nodes[np.where(test_nodes > self.spike_loc)] expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + self.spike_height actual = 
slab_spike.cdf(test_nodes) diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index af0c81f2f..cdd4e11b5 100644 --- a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -3,24 +3,27 @@ import numpy as np import pandas as pd +import pytest import bilby from bilby.gw import conversion +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestBasicConversions(unittest.TestCase): def setUp(self): - self.mass_1 = 1.4 - self.mass_2 = 1.3 - self.mass_ratio = 13 / 14 - self.total_mass = 2.7 - self.chirp_mass = (1.4 * 1.3) ** 0.6 / 2.7 ** 0.2 - self.symmetric_mass_ratio = (1.4 * 1.3) / 2.7 ** 2 - self.cos_angle = -1 - self.angle = np.pi - self.lambda_1 = 300 - self.lambda_2 = 300 * (14 / 13) ** 5 - self.lambda_tilde = ( + self.mass_1 = self.xp.array(1.4) + self.mass_2 = self.xp.array(1.3) + self.mass_ratio = self.xp.array(13 / 14) + self.total_mass = self.xp.array(2.7) + self.chirp_mass = (self.mass_1 * self.mass_2) ** 0.6 / self.total_mass ** 0.2 + self.symmetric_mass_ratio = (self.mass_1 * self.mass_2) / self.total_mass ** 2 + self.cos_angle = self.xp.array(-1.0) + self.angle = self.xp.pi + self.lambda_1 = self.xp.array(300.0) + self.lambda_2 = self.xp.array(300.0 * (14 / 13) ** 5) + self.lambda_tilde = self.xp.array( 8 / 13 * ( @@ -39,7 +42,7 @@ def setUp(self): * (self.lambda_1 - self.lambda_2) ) ) - self.delta_lambda_tilde = ( + self.delta_lambda_tilde = self.xp.array( 1 / 2 * ( @@ -75,30 +78,36 @@ def test_total_mass_and_mass_ratio_to_component_masses(self): self.assertTrue( all([abs(mass_1 - self.mass_1) < 1e-5, abs(mass_2 - self.mass_2) < 1e-5]) ) + self.assertEqual(mass_1.__array_namespace__(), self.xp) + self.assertEqual(mass_2.__array_namespace__(), self.xp) def test_chirp_mass_and_primary_mass_to_mass_ratio(self): mass_ratio = conversion.chirp_mass_and_primary_mass_to_mass_ratio( self.chirp_mass, self.mass_1 ) self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertEqual(mass_ratio.__array_namespace__(), 
self.xp) def test_symmetric_mass_ratio_to_mass_ratio(self): mass_ratio = conversion.symmetric_mass_ratio_to_mass_ratio( self.symmetric_mass_ratio ) self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertEqual(mass_ratio.__array_namespace__(), self.xp) def test_chirp_mass_and_total_mass_to_symmetric_mass_ratio(self): symmetric_mass_ratio = conversion.chirp_mass_and_total_mass_to_symmetric_mass_ratio( self.chirp_mass, self.total_mass ) self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) + self.assertEqual(symmetric_mass_ratio.__array_namespace__(), self.xp) def test_chirp_mass_and_mass_ratio_to_total_mass(self): total_mass = conversion.chirp_mass_and_mass_ratio_to_total_mass( self.chirp_mass, self.mass_ratio ) self.assertAlmostEqual(self.total_mass, total_mass) + self.assertEqual(total_mass.__array_namespace__(), self.xp) def test_chirp_mass_and_mass_ratio_to_component_masses(self): mass_1, mass_2 = \ @@ -106,30 +115,37 @@ def test_chirp_mass_and_mass_ratio_to_component_masses(self): self.chirp_mass, self.mass_ratio) self.assertAlmostEqual(self.mass_1, mass_1) self.assertAlmostEqual(self.mass_2, mass_2) + self.assertEqual(mass_1.__array_namespace__(), self.xp) + self.assertEqual(mass_2.__array_namespace__(), self.xp) def test_component_masses_to_chirp_mass(self): chirp_mass = conversion.component_masses_to_chirp_mass(self.mass_1, self.mass_2) self.assertAlmostEqual(self.chirp_mass, chirp_mass) + self.assertEqual(chirp_mass.__array_namespace__(), self.xp) def test_component_masses_to_total_mass(self): total_mass = conversion.component_masses_to_total_mass(self.mass_1, self.mass_2) self.assertAlmostEqual(self.total_mass, total_mass) + self.assertEqual(total_mass.__array_namespace__(), self.xp) def test_component_masses_to_symmetric_mass_ratio(self): symmetric_mass_ratio = conversion.component_masses_to_symmetric_mass_ratio( self.mass_1, self.mass_2 ) self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) + 
self.assertEqual(symmetric_mass_ratio.__array_namespace__(), self.xp) def test_component_masses_to_mass_ratio(self): mass_ratio = conversion.component_masses_to_mass_ratio(self.mass_1, self.mass_2) self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertEqual(mass_ratio.__array_namespace__(), self.xp) def test_mass_1_and_chirp_mass_to_mass_ratio(self): mass_ratio = conversion.mass_1_and_chirp_mass_to_mass_ratio( self.mass_1, self.chirp_mass ) self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertEqual(mass_ratio.__array_namespace__(), self.xp) def test_lambda_tilde_to_lambda_1_lambda_2(self): lambda_1, lambda_2 = conversion.lambda_tilde_to_lambda_1_lambda_2( @@ -143,6 +159,8 @@ def test_lambda_tilde_to_lambda_1_lambda_2(self): ] ) ) + self.assertEqual(lambda_1.__array_namespace__(), self.xp) + self.assertEqual(lambda_2.__array_namespace__(), self.xp) def test_lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2(self): ( @@ -159,18 +177,22 @@ def test_lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2(self): ] ) ) + self.assertEqual(lambda_1.__array_namespace__(), self.xp) + self.assertEqual(lambda_2.__array_namespace__(), self.xp) def test_lambda_1_lambda_2_to_lambda_tilde(self): lambda_tilde = conversion.lambda_1_lambda_2_to_lambda_tilde( self.lambda_1, self.lambda_2, self.mass_1, self.mass_2 ) self.assertTrue((self.lambda_tilde - lambda_tilde) < 1e-5) + self.assertEqual(lambda_tilde.__array_namespace__(), self.xp) def test_lambda_1_lambda_2_to_delta_lambda_tilde(self): delta_lambda_tilde = conversion.lambda_1_lambda_2_to_delta_lambda_tilde( self.lambda_1, self.lambda_2, self.mass_1, self.mass_2 ) self.assertTrue((self.delta_lambda_tilde - delta_lambda_tilde) < 1e-5) + self.assertEqual(delta_lambda_tilde.__array_namespace__(), self.xp) def test_identity_conversion(self): original_samples = dict( From 30b89a87a0353ffd498c0ac04699809ba7ddcc28 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 06:33:39 -0500 Subject: [PATCH 
079/110] TEST: more debugging slab spike test --- .github/workflows/unit-tests.yml | 6 +++--- test/core/prior/slabspike_test.py | 10 ++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 395248530..0220cc60a 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -57,13 +57,13 @@ jobs: # - name: Run precommits # run: | # pre-commit run --all-files --verbose --show-diff-on-failure - - name: Run unit tests - run: | - python -m pytest --cov=bilby --cov-branch --durations 10 -ra --color yes --cov-report=xml --junitxml=pytest.xml - name: Run jax-backend unit tests run: | python -m pip install .[jax] pytest --array-backend jax --durations 10 + - name: Run unit tests + run: | + python -m pytest --cov=bilby --cov-branch --durations 10 -ra --color yes --cov-report=xml --junitxml=pytest.xml - name: Run sampler tests run: | pytest test/integration/sampler_run_test.py --durations 10 -v diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index 53eb5ebee..6109e79f3 100644 --- a/test/core/prior/slabspike_test.py +++ b/test/core/prior/slabspike_test.py @@ -155,15 +155,17 @@ def test_cdf_below_spike(self): expected = slab.cdf(test_nodes) * slab_spike.slab_fraction actual = slab_spike.cdf(test_nodes) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(actual.__array_namespace__(), self.xp) def test_cdf_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): print(slab) expected = slab.cdf(self.spike_loc) * slab_spike.slab_fraction - actual = slab_spike.cdf(self.spike_loc) + actual = slab_spike.cdf(self.xp.asarray(self.spike_loc)) + if isinstance(slab, bilby.core.prior.Gaussian): + print(type(expected), type(actual)) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) - self.assertEqual(expected.__array_namespace__(), 
self.xp) + self.assertEqual(actual.__array_namespace__(), self.xp) def test_cdf_above_spike(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): @@ -172,7 +174,7 @@ def test_cdf_above_spike(self): expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + self.spike_height actual = slab_spike.cdf(test_nodes) np.testing.assert_allclose(expected, actual, rtol=1e-12) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(actual.__array_namespace__(), self.xp) def test_cdf_at_minimum(self): for slab_spike in self.slab_spikes: From 43cf406e50b9046ece2849d23ee6e19f9e68e7a2 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 06:46:23 -0500 Subject: [PATCH 080/110] TEST: jax tests work again --- .github/workflows/unit-tests.yml | 6 +++--- test/core/prior/slabspike_test.py | 5 ----- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 0220cc60a..395248530 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -57,13 +57,13 @@ jobs: # - name: Run precommits # run: | # pre-commit run --all-files --verbose --show-diff-on-failure + - name: Run unit tests + run: | + python -m pytest --cov=bilby --cov-branch --durations 10 -ra --color yes --cov-report=xml --junitxml=pytest.xml - name: Run jax-backend unit tests run: | python -m pip install .[jax] pytest --array-backend jax --durations 10 - - name: Run unit tests - run: | - python -m pytest --cov=bilby --cov-branch --durations 10 -ra --color yes --cov-report=xml --junitxml=pytest.xml - name: Run sampler tests run: | pytest test/integration/sampler_run_test.py --durations 10 -v diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index 6109e79f3..cd066ef05 100644 --- a/test/core/prior/slabspike_test.py +++ b/test/core/prior/slabspike_test.py @@ -150,7 +150,6 @@ def 
test_inverse_cdf_below_spike_arbitrary_position(self): def test_cdf_below_spike(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): - print(slab) test_nodes = test_nodes[np.where(test_nodes < self.spike_loc)] expected = slab.cdf(test_nodes) * slab_spike.slab_fraction actual = slab_spike.cdf(test_nodes) @@ -159,17 +158,13 @@ def test_cdf_below_spike(self): def test_cdf_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - print(slab) expected = slab.cdf(self.spike_loc) * slab_spike.slab_fraction actual = slab_spike.cdf(self.xp.asarray(self.spike_loc)) - if isinstance(slab, bilby.core.prior.Gaussian): - print(type(expected), type(actual)) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) self.assertEqual(actual.__array_namespace__(), self.xp) def test_cdf_above_spike(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): - print(slab) test_nodes = test_nodes[np.where(test_nodes > self.spike_loc)] expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + self.spike_height actual = slab_spike.cdf(test_nodes) From ba6f1ce4c7ae95ada434f8371b3be21bee9ef09c Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 06:46:41 -0500 Subject: [PATCH 081/110] DOC: add initial doc page for array backend --- docs/array_api.rst | 528 +++++++++++++++++++++++++++++++++++++++++++++ docs/index.txt | 1 + 2 files changed, 529 insertions(+) create mode 100644 docs/array_api.rst diff --git a/docs/array_api.rst b/docs/array_api.rst new file mode 100644 index 000000000..57c187d72 --- /dev/null +++ b/docs/array_api.rst @@ -0,0 +1,528 @@ +===================== +Array API Support +===================== + +Bilby now supports the Python `Array API Standard `_, +enabling the use of different array backends (NumPy, JAX, CuPy, etc.) for improved performance +and hardware acceleration. This page describes how to use this functionality and how it works internally. 
+ +For Users and Downstream Developers +==================================== + +Overview +-------- + +The Array API support allows you to use different array libraries with Bilby seamlessly. +This can significantly improve performance, especially when using hardware accelerators like GPUs +or when you need automatic differentiation capabilities. + +**Key principle**: In most cases, you don't need to explicitly specify which array backend to use. +Bilby automatically detects the array type you're working with and uses the appropriate backend. +Simply pass JAX arrays, CuPy arrays, or NumPy arrays to prior methods, and Bilby handles the rest. + +Supported Backends +------------------ + +Bilby is currently tested with the following array backends: + +- **NumPy** (default): Standard CPU-based computations +- **JAX**: GPU/TPU acceleration and automatic differentiation + +While :code:`Bilby` should be compatible with other Array API compliant libraries, +these are not currently tested or officially supported. +If you notice any issues when using other backends, +please report them on the `Bilby GitHub repository `. + +Using Different Array Backends +------------------------------- + +Basic Prior Usage (Automatic Detection) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The array backend is automatically detected from your input arrays. You typically don't need +to specify the ``xp`` parameter:: + +.. 
code-block:: python + + import bilby + import jax.numpy as jnp + import numpy as np + + prior = bilby.core.prior.Uniform(minimum=0, maximum=10) + + # Using JAX - backend automatically detected + val_jax = jnp.array([0.5, 1.5, 2.5]) + prob_jax = prior.prob(val_jax) # Returns JAX array + + # Using NumPy - backend automatically detected + val_np = np.array([0.5, 1.5, 2.5]) + prob_np = prior.prob(val_np) # Returns NumPy array + +Sampling with Array Backends (Explicit xp Required) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When sampling from priors, you **must** explicitly specify the array backend using the ``xp`` parameter, +as there's no input array to infer the backend from:: + +.. code-block:: python + + import bilby + import jax.numpy as jnp + + prior = bilby.core.prior.Uniform(minimum=0, maximum=10) + samples = prior.sample(size=1000, xp=jnp) # Returns JAX array + + # Or with NumPy (default) + samples_np = prior.sample(size=1000) # Or explicitly: xp=np + +.. note:: + + Currently, prior sampling is done by first generating uniform samples in [0, 1] + using :code:`NumPy`, then converting to the desired backend. + In future releases, this may be altered to generate samples directly in the specified backend. + +Prior Dictionaries +~~~~~~~~~~~~~~~~~~ + +Prior dictionaries work the same way - automatic detection for most methods, explicit ``xp`` for sampling:: + +.. code-block:: python + + import bilby + import jax.numpy as jnp + + priors = bilby.core.prior.PriorDict({ + 'x': bilby.core.prior.Uniform(0, 100), + 'y': bilby.core.prior.Uniform(0, 1) + }) + + # Sampling requires explicit xp + samples = priors.sample(size=1000, xp=jnp) + + # Evaluation automatically detects backend from input + theta = jnp.array([50.0, 0.5]) + prob = priors.prob(samples) # Automatically uses JAX + +Core Likelihoods and Sampling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Core :code:`Bilby` likelihoods are compatible with the Array API. 
+When using :code:`JAX` arrays, you can take advantage of :code:`JAX`'s JIT compilation and automatic differentiation. +For :code:`JAX`-compatible samplers (e.g., :code:`numpyro`), +you can pass any :code:`JAX`-compatible :code:`Bilby` likelihood directly. +For non-:code:`JAX` samplers, you should wrap your likelihood with the +:code:`bilby.compat.jax.JittedLikelihood` class to enable JIT compilation. + +.. code-block:: python + + import bilby + import jax.numpy as jnp + from bilby.compat.jax import JittedLikelihood + + class MyLikelihood(bilby.Likelihood): + def log_likelihood(self, parameters): + # model returns a JAX array if passed a dictionary of JAX arrays + return -0.5 * xp.sum((self.data - model(parameters))**2) + + data = jnp.array([...]) # Your data as a JAX array + + priors = bilby.core.prior.PriorDict({ + 'param1': bilby.core.prior.Uniform(0, 10), + 'param2': bilby.core.prior.Uniform(-5, 5) + }) + + likelihood = MyLikelihood(data) + + # call the likelihood once in case any initial setup is needed + likelihood.log_likelihood(priors.sample()) + + # Wrap with JittedLikelihood for JAX + jitted_likelihood = JittedLikelihood(likelihood) + + # call the jitted likelihood once to trigger JIT compilation + # the JittedLikelihood automatically converts the parameters + # to JAX arrays + jitted_likelihood.log_likelihood(priors.sample()) + + # Use with a JAX-incompatible sampler + sampler = bilby.run_sampler(likelihood=jitted_likelihood, ...) + +Gravitational-Wave Likelihoods +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The :code:`Bilby` implementation of gravitational-wave likelihood is compatible with the Array API, +however this requires access to waveform models that support the provided array backend. +The desired array backend must be explicitly specified for the data, +using :code:`bilby.gw.detector.networks.InterferometerList.set_array_backend`. +Below is an example using the :code:`ripplegw` package for waveform generation. 
+Here, an injection is performed using the standard :code:`LALSimulation` waveform generator, +and the analysis is then performed using the JIT-compiled likelihood. + +.. code-block:: python + + import bilby + import jax.numpy as jnp + import ripplegw + + priors = bilby.gw.prior.BBHPriorDict() + priors["geocent_time"] = bilby.core.prior.Uniform(1126259462.4, 1126259462.6) + injection_parameters = priors.sample() + + # Create interferometers and inject signal using standard waveform generator + ifos = bilby.gw.detector.networks.InterferometerList(['H1', 'L1']) + ifos.set_strain_data_from_power_spectral_densities( + sampling_frequency=2048, + duration=4, + start_time=injection_parameters["geocent_time"] - 2 + ) + injection_wfg = bilby.gw.waveform_generator.WaveformGenerator( + duration=4, + sampling_frequency=2048, + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, + waveform_arguments={"approximant": "IMRPhenomXODE"} + ) + ifos.inject_signal(parameters=injection_parameters, waveform_generator=injection_wfg) + + # set the array backend after the injection + ifos.set_array_backend(jnp) + + ripple_wfg = bilby.gw.waveform_generator.WaveformGenerator( + duration=4, + sampling_frequency=2048, + frequency_domain_source_model=ripplegw.get_fd_waveform + ) + + # Create gravitational-wave likelihood + likelihood = bilby.gw.likelihood.GravitationalWaveTransient( + interferometers=ifos, + waveform_generator=ripple_wfg, + priors=priors, + phase_marginalization=True, + ) + # call the likelihood once to do some initial setup + # this is needed for the gravitational-wave transient likelihoods + likelihood.log_likelihood_ratio(priors.sample()) + + # Wrap with JittedLikelihood for JAX and JIT compile + jitted_likelihood = bilby.compat.jax.JittedLikelihood(likelihood) + jitted_likelihood.log_likelihood_ratio(priors.sample()) + +.. note:: + + All of the likelihood marginalizations implemented in :code:`Bilby` are compatible with the Array API. 
+ However, there is currently a performance issue with the distance marginalized likelihood + using the :code:`JAX` backend. + +Performance Considerations +-------------------------- + +**When to use JAX:** + +- GPU/TPU acceleration is available +- You need automatic differentiation +- Working with large datasets or many parameters +- Repeated evaluations benefit from JIT compilation + +**When to use NumPy:** + +- Simple CPU-based computations +- Small datasets +- Maximum compatibility +- Debugging (easier to inspect values) + +**Best Practices:** + +1. Let Bilby detect the array backend automatically - only specify ``xp`` when sampling +2. Use array backend consistently throughout your analysis +3. Avoid mixing array types in the same computation +4. For JAX, consider using ``jax.jit`` for repeated computations +5. Profile your code to ensure the chosen backend provides benefits + +Bilby and JIT compilation +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Currently, Bilby functions are not JIT-compiled by default. +Additionally, many Bilby types are not defined as :code:`JAX`` :code:`PyTrees`, +and so cannot be passed as arguments to JIT-compiled functions. +We plan to support JIT-compilation for at least some Bilby types in future releases. + +Custom Priors with Array API +----------------------------- + +When creating custom priors, ensure they support the Array API: + +Example Implementation +~~~~~~~~~~~~~~~~~~~~~~ + +Always include the ``xp`` parameter with a default value:: + +... 
code-block:: python
+
+    from bilby.core.prior import Prior
+
+    class MyCustomPrior(Prior):
+        def __init__(self, parameter, **kwargs):
+            super().__init__(**kwargs)
+            self.parameter = parameter
+
+        def rescale(self, val, *, xp=None):
+            """Rescale method with xp parameter."""
+            return self.minimum + val * (self.maximum - self.minimum) * self.parameter
+
+        def prob(self, val, *, xp=None):
+            """Probability method with xp parameter."""
+            in_range = (val >= self.minimum) & (val <= self.maximum)
+            return in_range / (self.maximum - self.minimum) * self.parameter
+
+The ``xp`` parameter should:
+
+- Be a keyword-only argument (after ``*``)
+- Have a default value (``None`` if method is decorated with ``@xp_wrap``, ``np`` otherwise)
+- Be passed through to any array operations if used directly
+
+**Note**: Users of your custom prior won't need to pass ``xp`` explicitly for evaluation methods -
+it will be automatically inferred from their input arrays. They only need to specify ``xp`` when sampling.
+
+Using the :code:`xp_wrap` Decorator
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+For methods that perform array operations, use the ``@xp_wrap`` decorator::
+
+.. 
code-block:: python
+
+    from bilby.core.prior import Prior
+    from bilby.compat.utils import xp_wrap
+    import numpy as np
+
+    class MyCustomPrior(Prior):
+        @xp_wrap
+        def prob(self, val, *, xp=None):
+            """The decorator handles xp=None automatically."""
+            return xp.exp(-val) / self.normalization * self.is_in_prior_range(val)
+
+        @xp_wrap
+        def ln_prob(self, val, *, xp=None):
+            """Works with logarithmic operations."""
+            return -val - xp.log(self.normalization) + xp.log(self.is_in_prior_range(val))
+
+The ``@xp_wrap`` decorator:
+
+- Automatically provides the appropriate array module when ``xp=None``
+- Infers the array backend from input arrays when they are :code:`JAX`/:code:`CuPy`/:code:`PyTorch` arrays
+- Falls back to NumPy when the input is a standard Python type or NumPy array
+- Handles the conversion seamlessly so users don't need to specify ``xp``
+
+For Bilby Developers
+=====================
+
+Architecture Overview
+---------------------
+
+The Array API support in Bilby is built around several key components:
+
+1. **The xp parameter**: A keyword-only parameter added to prior methods
+2. **The @xp_wrap decorator**: Handles array module selection and injection
+3. **Compatibility utilities**: Helper functions for array module detection
+
+Core Changes to Prior Base Class
+---------------------------------
+
+The ``Prior`` base class in ``bilby/core/prior/base.py`` includes these key changes:
+
+Method Signature Pattern
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+All array-processing methods in prior classes follow this pattern:
+
+**For methods with @xp_wrap decorator**::
+
+.. code-block:: python
+
+    @xp_wrap
+    def prob(self, val, *, xp=None):
+        """Method that uses xp for array operations."""
+        return xp.some_operation(val) * self.is_in_prior_range(val)
+
+**For methods without @xp_wrap (that use xp directly)**::
+
+.. 
code-block:: python
+
+    def sample(self, size=None, *, xp=np):
+        """Method that uses xp but isn't wrapped."""
+        return xp.array(random.rng.uniform(0, 1, size))
+
+Key rules:
+
+- ``xp`` is always keyword-only (after ``*``)
+- Methods with ``@xp_wrap`` use ``xp=None`` as default
+- Methods without ``@xp_wrap`` that use ``xp`` use ``xp=np`` as default
+- Methods that don't use ``xp`` have ``xp=None`` as default
+
+The :code:`@xp_wrap` Decorator
+-------------------------------
+
+Located in ``bilby/compat/utils.py``, this decorator:
+
+1. **Inspects input arguments** to determine the array module in use
+2. **Provides the appropriate xp** when ``xp=None``
+3. **Maintains backward compatibility** with code that doesn't pass ``xp``
+
+Example implementation pattern::
+
+.. code-block:: python
+
+    from bilby.compat.utils import xp_wrap
+
+    @xp_wrap
+    def my_function(val, *, xp=None):
+        # When called:
+        # - If xp=None, decorator infers from val
+        # - If xp is provided, uses that
+        # - Returns results in the same array type as input
+        return xp.exp(val) / xp.mean(val)
+
+Testing Array API Support
+-------------------------
+
+Test Structure
+~~~~~~~~~~~~~~
+
+When appropriate, tests should verify functionality across different
+backends using the ``array_backend`` marker::
+
+    @pytest.mark.array_backend
+    @pytest.mark.usefixtures("xp_class")
+    class TestMyPrior:
+        def test_prob(self):
+            prior = MyPrior()
+            val = self.xp.array([0.5, 1.5, 2.5])
+            # No need to pass xp - automatically detected
+            prob = prior.prob(val)
+            assert self.xp.all(prob >= 0)
+            assert prob.__array_namespace__() == self.xp
+
+        def test_sample(self):
+            prior = MyPrior()
+            # Sampling requires explicit xp
+            samples = prior.sample(size=100, xp=self.xp)
+            assert samples.__array_namespace__() == self.xp
+
+The array_backend Marker
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``@pytest.mark.array_backend`` marker is used to indicate that a test or test class should be run
+with multiple array backends. 
When you run pytest with the ``--array-backend`` flag, only tests marked +with ``array_backend`` will be executed with that specific backend. + +Without the marker, tests run with the default NumPy backend only. With the marker: + +- Tests are parametrized to run with different backends +- The ``xp_class`` fixture is available, providing access to the array module via ``self.xp`` +- Tests verify that code works correctly regardless of the array backend + +Running Tests with Different Backends +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use the ``--array-backend`` flag to test with specific backends:: + + # Test with NumPy (default) + pytest test/core/prior/analytical_test.py + + # Test with JAX backend + pytest --array-backend jax test/core/prior/analytical_test.py + + # Test with CuPy backend + pytest --array-backend cupy test/core/prior/analytical_test.py + +Bilby automatically sets ``SCIPY_ARRAY_API=1`` on import, so you don't need to set this +environment variable manually. The ``--array-backend`` flag controls which backend the +``xp_class`` fixture provides to your tests. + +Migration Guide from Previous Versions +-------------------------------------- + +Key Differences +~~~~~~~~~~~~~~~ + +1. **Method signatures changed**: All prior methods now include ``xp`` parameter +2. **Decorator added**: Many methods now use ``@xp_wrap`` +3. **Default values differ**: Methods with ``@xp_wrap`` use ``xp=None``, others use ``xp=np`` +4. **Validation added**: Custom priors are checked for ``xp`` support + +Best Practices for Contributors +-------------------------------- + +When adding or modifying prior methods: + +1. **Always include xp parameter** in prob, ln_prob, rescale, cdf, sample methods +2. **Use @xp_wrap decorator** for methods doing array operations +3. **Set correct default**: ``xp=None`` with decorator, ``xp=np`` without (for methods that use xp directly) +4. **Pass xp through**: When calling other methods, pass ``xp=xp`` +5. 
**Test with multiple backends**: Use ``@pytest.mark.array_backend`` and test with ``--array-backend jax`` +6. **Document xp parameter**: Note it in docstrings, but emphasize it's usually auto-detected +7. **Use array module functions**: Use ``xp.function()`` not ``np.function()`` in wrapped methods + +Handling Array Updates with :code:`array_api_extra.at`` +------------------------------------------------------- + +One key difference between array backends is how they handle array updates. +NumPy allows in-place modification of array slices, +while JAX requires functional updates since arrays are immutable. +The ``array_api_extra.at`` function provides a unified interface for array updates across backends. + +Usage Examples +~~~~~~~~~~~~~~ + +**Conditional update**:: + +.. code-block:: python + + @xp_wrap + def conditional_update(vals, *, xp=None): + """Update array elements where mask is True.""" + arr = vals**2 + mask = arr > 0.5 + # Instead of: arr[mask] = value + arr = xpx.at(arr)[mask].set(value) + return arr + +**Increment operation**:: + +.. code-block:: python + + @xp_wrap + def increment_slice(arr, *, xp=None): + """Add values to a slice of an array.""" + # Instead of: arr[2:5] += values + arr = xpx.at(arr)[2:5].add(values) + return arr + +Available Operations +~~~~~~~~~~~~~~~~~~~~ + +The ``at`` function supports several operations: + +- ``set(values)``: Replace values at specified indices +- ``add(values)``: Add values to specified indices +- ``multiply(values)``: Multiply specified indices by values +- ``min(values)``: Take element-wise minimum +- ``max(values)``: Take element-wise maximum + +Important Notes +~~~~~~~~~~~~~~~ + +1. **Return value**: Always use the returned array. The operation may create a new array (JAX) or modify in-place (NumPy). + +2. **Import**: Import ``array_api_extra`` at the module level:: + +.. 
code-block:: python + + import array_api_extra as xpx + +Further Resources +----------------- + +- `Array API Standard `_ +- `JAX Documentation `_ +- `array-api-compat Package `_ +- `array-api-extra Package `_ diff --git a/docs/index.txt b/docs/index.txt index ff6e12c85..d8fabb550 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -16,6 +16,7 @@ Welcome to bilby's documentation! prior likelihood samplers + array_api dynesty-guide bilby-mcmc-guide rng From 0a6f1e26462832ad6d8ac46d5c8dd736590df707 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 10:58:06 -0500 Subject: [PATCH 082/110] TEST: add a bunch of gw tests --- bilby/gw/utils.py | 27 +++-- bilby/gw/waveform_generator.py | 3 +- test/conftest.py | 15 ++- test/gw/conversion_test.py | 55 ++++++---- test/gw/detector/geometry_test.py | 23 +++- test/gw/likelihood_test.py | 171 +++++++++++++++++++++-------- test/gw/prior_test.py | 7 +- test/gw/utils_test.py | 87 +++++++++++---- test/gw/waveform_generator_test.py | 47 ++++++-- 9 files changed, 311 insertions(+), 124 deletions(-) diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index 07a9e9aa1..7799c23ca 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -8,7 +8,7 @@ from .geometry import zenith_azimuth_to_theta_phi from .time import greenwich_mean_sidereal_time -from ..compat.utils import array_module +from ..compat.utils import array_module, xp_wrap from ..core.utils import (logger, run_commandline, check_directory_exists_and_if_not_mkdir, SamplesSummary, theta_phi_to_ra_dec) @@ -31,7 +31,7 @@ def asd_from_freq_series(freq_data, df): array_like: array of real-valued normalized frequency domain ASD data """ - return np.absolute(freq_data) * 2 * df**0.5 + return abs(freq_data) * 2 * df**0.5 def psd_from_freq_series(freq_data, df): @@ -51,7 +51,7 @@ def psd_from_freq_series(freq_data, df): array_like: Real-valued normalized frequency domain PSD data """ - return np.power(asd_from_freq_series(freq_data, df), 2) + return 
asd_from_freq_series(freq_data, df) ** 2 def get_vertex_position_geocentric(latitude, longitude, elevation): @@ -221,7 +221,7 @@ def overlap(signal_a, signal_b, power_spectral_density=None, delta_frequency=Non """ low_index = int(lower_cut_off / delta_frequency) up_index = int(upper_cut_off / delta_frequency) - integrand = np.conj(signal_a) * signal_b + integrand = signal_a.conjugate() * signal_b integrand = integrand[low_index:up_index] / power_spectral_density[low_index:up_index] integral = (4 * delta_frequency * integrand) / norm_a / norm_b return sum(integral).real @@ -247,8 +247,6 @@ def zenith_azimuth_to_ra_dec(zenith, azimuth, geocent_time, ifos): ra, dec: float The zenith and azimuthal angles in the sky frame. """ - # delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex - # theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x) theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, ifos) gmst = greenwich_mean_sidereal_time(geocent_time) ra, dec = theta_phi_to_ra_dec(theta, phi, gmst) @@ -924,7 +922,8 @@ def lalsim_SimNeutronStarLoveNumberK2(mass_in_SI, fam): return SimNeutronStarLoveNumberK2(mass_in_SI, fam) -def spline_angle_xform(delta_psi): +@xp_wrap +def spline_angle_xform(delta_psi, *, xp=None): """ Returns the angle in degrees corresponding to the spline calibration parameters delta_psi. 
@@ -941,7 +940,7 @@ def spline_angle_xform(delta_psi): """ rotation = (2.0 + 1.0j * delta_psi) / (2.0 - 1.0j * delta_psi) - return 180.0 / np.pi * np.arctan2(np.imag(rotation), np.real(rotation)) + return 180.0 / np.pi * xp.arctan2(xp.imag(rotation), xp.real(rotation)) def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label=None, xform=None): @@ -1002,7 +1001,8 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label= plt.xlim(freq_points.min() - .5, freq_points.max() + 50) -def ln_i0(value): +@xp_wrap +def ln_i0(value, *, xp=None): """ A numerically stable method to evaluate ln(I_0) a modified Bessel function of order 0 used in the phase-marginalized likelihood. @@ -1017,7 +1017,6 @@ def ln_i0(value): array-like: The natural logarithm of the bessel function """ - xp = array_module(value) return xp.log(i0e(value)) + xp.abs(value) @@ -1047,10 +1046,10 @@ def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1): import lalsimulation return safety * lalsimulation.SimInspiralTaylorF2ReducedSpinChirpTime( - frequency, - mass_1 * solar_mass, - mass_2 * solar_mass, - chi, + float(frequency), + float(mass_1 * solar_mass), + float(mass_2 * solar_mass), + float(chi), -1 ) diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py index 7cae71618..4043caa0d 100644 --- a/bilby/gw/waveform_generator.py +++ b/bilby/gw/waveform_generator.py @@ -1,3 +1,4 @@ +import array_api_compat as aac import numpy as np from ..core import utils @@ -201,7 +202,7 @@ def _strain_from_transformed_model( transformed_model_data_points, transformed_model, parameters ) - if isinstance(transformed_model_strain, np.ndarray): + if aac.is_array_api_obj(transformed_model_strain): return transformation_function(transformed_model_strain, self.sampling_frequency) model_strain = dict() diff --git a/test/conftest.py b/test/conftest.py index d0d1ad79b..b1668561a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -21,14 
+21,17 @@ def pytest_configure(config): def pytest_collection_modifyitems(config, items): if config.getoption("--skip-roqs"): skip_roqs = pytest.mark.skip(reason="Skipping tests that require ROQs") - for item in items: - if "requires_roqs" in item.keywords: - item.add_marker(skip_roqs) + else: + skip_roqs = pytest.mark.noop if config.getoption("--array-backend") is not None: array_only = pytest.mark.skip(reason="Only running backend dependent tests") - for item in items: - if "array_backend" not in item.keywords: - item.add_marker(array_only) + else: + array_only = pytest.mark.noop + for item in items: + if "requires_roqs" in item.keywords and config.getoption("--skip-roqs"): + item.add_marker(skip_roqs) + elif "array_backend" not in item.keywords: + item.add_marker(array_only) def _xp(request): diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index cdd4e11b5..621a766ae 100644 --- a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -626,19 +626,20 @@ def test_comoving_luminosity_with_cosmology(self): self.assertAlmostEqual(max(abs(dl - self.distances)), 0, 4) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGenerateMassParameters(unittest.TestCase): def setUp(self): - self.expected_values = {'mass_1': 2.0, - 'mass_2': 1.0, - 'chirp_mass': 1.2167286837864113, - 'total_mass': 3.0, - 'mass_1_source': 4.0, - 'mass_2_source': 2.0, - 'chirp_mass_source': 2.433457367572823, - 'total_mass_source': 6, - 'symmetric_mass_ratio': 0.2222222222222222, - 'mass_ratio': 0.5} - + self.expected_values = {'mass_1': self.xp.array(2.0), + 'mass_2': self.xp.array(1.0), + 'chirp_mass': self.xp.array(1.2167286837864113), + 'total_mass': self.xp.array(3.0), + 'mass_1_source': self.xp.array(4.0), + 'mass_2_source': self.xp.array(2.0), + 'chirp_mass_source': self.xp.array(2.433457367572823), + 'total_mass_source': self.xp.array(6), + 'symmetric_mass_ratio': self.xp.array(0.2222222222222222), + 'mass_ratio': self.xp.array(0.5)} def 
helper_generation_from_keys(self, keys, expected_values, source=False): # Explicitly test the helper generate_component_masses local_test_vars = \ @@ -685,6 +686,10 @@ def helper_generation_from_keys(self, keys, expected_values, source=False): ) for key in local_all_mass_parameters.keys(): self.assertAlmostEqual(expected_values[key], local_all_mass_parameters[key]) + self.assertEqual( + local_all_mass_parameters[key].__array_namespace__(), + self.xp, + ) def test_from_mass_1_and_mass_2(self): self.helper_generation_from_keys(["mass_1", "mass_2"], @@ -751,6 +756,8 @@ def test_from_chirp_mass_source_and_symmetric_mass_2(self): self.expected_values, source=True) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestEquationOfStateConversions(unittest.TestCase): ''' Class to test equation of state conversions. @@ -759,48 +766,48 @@ class TestEquationOfStateConversions(unittest.TestCase): ''' def setUp(self): - self.mass_1_source_spectral = [ + self.mass_1_source_spectral = self.xp.array([ 4.922542724434885, 4.350626907771598, 4.206155335439082, 1.7822696459661311, 1.3091740103047926 - ] - self.mass_2_source_spectral = [ + ]) + self.mass_2_source_spectral = self.xp.array([ 3.459974694590303, 1.2276461777181447, 3.7287707089639976, 0.3724016563531846, 1.055042934805801 - ] - self.spectral_pca_gamma_0 = [ + ]) + self.spectral_pca_gamma_0 = self.xp.array([ 0.7074873121348357, 0.05855931126849878, 0.7795329261793462, 1.467907561566463, 2.9066488405635624 - ] - self.spectral_pca_gamma_1 = [ + ]) + self.spectral_pca_gamma_1 = self.xp.array([ -0.29807111670823816, 2.027708558522935, -1.4415775226512115, -0.7104870098896858, -0.4913817181089619 - ] - self.spectral_pca_gamma_2 = [ + ]) + self.spectral_pca_gamma_2 = self.xp.array([ 0.25625095371021156, -0.19574096643220049, -0.2710238103460012, 0.22815820981582358, -0.1543413205016374 - ] - self.spectral_pca_gamma_3 = [ + ]) + self.spectral_pca_gamma_3 = self.xp.array([ -0.04030365100175101, 
0.05698030777919032, -0.045595911403040264, -0.023480394227900117, -0.07114492992285618 - ] + ]) self.spectral_gamma_0 = [ 1.1259406796075457, 0.3191335618787259, @@ -905,6 +912,8 @@ def test_spectral_pca_to_spectral(self): self.assertAlmostEqual(spectral_gamma_1, self.spectral_gamma_1[i], places=5) self.assertAlmostEqual(spectral_gamma_2, self.spectral_gamma_2[i], places=5) self.assertAlmostEqual(spectral_gamma_3, self.spectral_gamma_3[i], places=5) + for val in [spectral_gamma_0, spectral_gamma_1, spectral_gamma_2, spectral_gamma_3]: + self.assertEqual(val.__array_namespace__(), self.xp) def test_spectral_params_to_lambda_1_lambda_2(self): ''' diff --git a/test/gw/detector/geometry_test.py b/test/gw/detector/geometry_test.py index 358825b23..4906f00cc 100644 --- a/test/gw/detector/geometry_test.py +++ b/test/gw/detector/geometry_test.py @@ -2,10 +2,13 @@ from unittest import mock import numpy as np +import pytest import bilby +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestInterferometerGeometry(unittest.TestCase): def setUp(self): self.length = 30 @@ -26,6 +29,7 @@ def setUp(self): xarm_tilt=self.xarm_tilt, yarm_tilt=self.yarm_tilt, ) + self.geometry.set_array_backend(self.xp) def tearDown(self): del self.length @@ -40,27 +44,35 @@ def tearDown(self): def test_length_setting(self): self.assertEqual(self.geometry.length, self.length) + self.assertEqual(self.geometry.length.__array_namespace__(), self.xp) def test_latitude_setting(self): self.assertEqual(self.geometry.latitude, self.latitude) + self.assertEqual(self.geometry.latitude.__array_namespace__(), self.xp) def test_longitude_setting(self): self.assertEqual(self.geometry.longitude, self.longitude) + self.assertEqual(self.geometry.longitude.__array_namespace__(), self.xp) def test_elevation_setting(self): self.assertEqual(self.geometry.elevation, self.elevation) + self.assertEqual(self.geometry.elevation.__array_namespace__(), self.xp) def test_xarm_azi_setting(self): 
self.assertEqual(self.geometry.xarm_azimuth, self.xarm_azimuth) + self.assertEqual(self.geometry.xarm_azimuth.__array_namespace__(), self.xp) def test_yarm_azi_setting(self): self.assertEqual(self.geometry.yarm_azimuth, self.yarm_azimuth) + self.assertEqual(self.geometry.yarm_azimuth.__array_namespace__(), self.xp) def test_xarm_tilt_setting(self): self.assertEqual(self.geometry.xarm_tilt, self.xarm_tilt) + self.assertEqual(self.geometry.xarm_tilt.__array_namespace__(), self.xp) def test_yarm_tilt_setting(self): self.assertEqual(self.geometry.yarm_tilt, self.yarm_tilt) + self.assertEqual(self.geometry.yarm_tilt.__array_namespace__(), self.xp) def test_vertex_without_update(self): _ = self.geometry.vertex @@ -142,31 +154,37 @@ def test_detector_tensor_with_x_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.xarm_azimuth += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) def test_detector_tensor_with_y_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.yarm_azimuth += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) def test_detector_tensor_with_x_tilt_update(self): original = self.geometry.detector_tensor self.geometry.xarm_tilt += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) def test_detector_tensor_with_y_tilt_update(self): original = self.geometry.detector_tensor self.geometry.yarm_tilt += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) def test_detector_tensor_with_longitude_update(self): original = self.geometry.detector_tensor self.geometry.longitude += 1 
self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) def test_detector_tensor_with_latitude_update(self): original = self.geometry.detector_tensor self.geometry.latitude += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) def test_unit_vector_along_arm_default(self): with self.assertRaises(ValueError): @@ -177,17 +195,20 @@ def test_unit_vector_along_arm_x(self): self.geometry.latitude = 0 self.geometry.xarm_tilt = 0 self.geometry.xarm_azimuth = 0 + self.geometry.set_array_backend(self.xp) arm = self.geometry.unit_vector_along_arm("x") self.assertTrue(np.allclose(arm, np.array([0, 1, 0]))) + self.assertEqual(arm.__array_namespace__(), self.xp) def test_unit_vector_along_arm_y(self): self.geometry.longitude = 0 self.geometry.latitude = 0 self.geometry.yarm_tilt = 0 self.geometry.yarm_azimuth = 90 + self.geometry.set_array_backend(self.xp) arm = self.geometry.unit_vector_along_arm("y") - print(arm) self.assertTrue(np.allclose(arm, np.array([0, 0, 1]))) + self.assertEqual(arm.__array_namespace__(), self.xp) def test_repr(self): expected = ( diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 95dbf75e4..087e00848 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -10,9 +10,40 @@ import h5py import numpy as np import bilby +from array_api_compat import is_array_api_obj from bilby.gw.likelihood import BilbyROQParamsRangeError +class BackendWaveformGenerator(bilby.gw.waveform_generator.WaveformGenerator): + """A thin wrapper to emulate different backends in the waveform generator.""" + def __init__(self, wfg, xp): + self.wfg = wfg + self.xp = xp + + def __getattr__(self, name): + if name == "xp": + return self.xp + return getattr(self.wfg, name) + + def convert_nested_dict(self, data): + if 
is_array_api_obj(data): + return self.xp.array(data) + elif isinstance(data, dict): + return {key: self.convert_nested_dict(value) for key, value in data.items()} + else: + raise ValueError("Input must be an array API object or a dict of such objects.") + + def frequency_domain_strain(self, parameters): + wf = self.wfg.frequency_domain_strain(parameters) + return self.convert_nested_dict(wf) + + def time_domain_strain(self, parameters): + wf = self.wfg.time_domain_strain(parameters) + return self.convert_nested_dict(wf) + + +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestBasicGWTransient(unittest.TestCase): def setUp(self): bilby.core.utils.random.seed(500) @@ -37,11 +68,13 @@ def setUp(self): self.interferometers.set_strain_data_from_power_spectral_densities( sampling_frequency=2048, duration=4 ) - self.waveform_generator = bilby.gw.waveform_generator.GWSignalWaveformGenerator( - duration=4, - sampling_frequency=2048, + self.interferometers.set_array_backend(self.xp) + base_wfg = bilby.gw.waveform_generator.GWSignalWaveformGenerator( + duration=self.xp.array(4.0), + sampling_frequency=self.xp.array(2048.0), waveform_arguments=dict(waveform_approximant="IMRPhenomPv2"), ) + self.waveform_generator = BackendWaveformGenerator(base_wfg, self.xp) self.likelihood = bilby.gw.likelihood.BasicGravitationalWaveTransient( interferometers=self.interferometers, @@ -57,29 +90,33 @@ def tearDown(self): def test_noise_log_likelihood(self): """Test noise log likelihood matches precomputed value""" - self.likelihood.noise_log_likelihood() + nll = self.likelihood.noise_log_likelihood() self.assertAlmostEqual( - -4014.1787704539474, self.likelihood.noise_log_likelihood(), 3 + -4014.1787704539474, nll, 3 ) + self.assertEqual(nll.__array_namespace__(), self.xp) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" - self.likelihood.log_likelihood() - self.assertAlmostEqual(self.likelihood.log_likelihood(), -4032.4397343470005, 3) 
+ logl = self.likelihood.log_likelihood(self.parameters) + self.assertAlmostEqual(logl, -4032.4397343470005, 3) + self.assertEqual(logl.__array_namespace__(), self.xp) def test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" + llr = self.likelihood.log_likelihood_ratio(self.parameters) self.assertAlmostEqual( - self.likelihood.log_likelihood() - self.likelihood.noise_log_likelihood(), - self.likelihood.log_likelihood_ratio(), + self.likelihood.log_likelihood(self.parameters) - self.likelihood.noise_log_likelihood(), + llr, 3, ) + self.assertEqual(llr.__array_namespace__(), self.xp) def test_likelihood_zero_when_waveform_is_none(self): """Test log likelihood returns np.nan_to_num(-np.inf) when the waveform is None""" self.likelihood.waveform_generator.frequency_domain_strain = lambda x: None - self.assertEqual(self.likelihood.log_likelihood_ratio(), np.nan_to_num(-np.inf)) + self.assertEqual(self.likelihood.log_likelihood_ratio(self.parameters), np.nan_to_num(-np.inf)) def test_repr(self): expected = "BasicGravitationalWaveTransient(interferometers={},\n\twaveform_generator={})".format( @@ -88,11 +125,13 @@ def test_repr(self): self.assertEqual(expected, repr(self.likelihood)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGWTransient(unittest.TestCase): def setUp(self): bilby.core.utils.random.seed(500) - self.duration = 4 - self.sampling_frequency = 2048 + self.duration = self.xp.array(4.0) + self.sampling_frequency = self.xp.array(2048.0) self.parameters = dict( mass_1=31.0, mass_2=29.0, @@ -114,11 +153,13 @@ def setUp(self): self.interferometers.set_strain_data_from_power_spectral_densities( sampling_frequency=self.sampling_frequency, duration=self.duration ) - self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + self.interferometers.set_array_backend(self.xp) + wfg = bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, 
sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, ) + self.waveform_generator = BackendWaveformGenerator(wfg, self.xp) self.prior = bilby.gw.prior.BBHPriorDict() self.prior["geocent_time"] = bilby.prior.Uniform( @@ -131,7 +172,6 @@ def setUp(self): waveform_generator=self.waveform_generator, priors=self.prior.copy(), ) - self.likelihood.parameters = self.parameters.copy() def tearDown(self): del self.parameters @@ -142,30 +182,33 @@ def tearDown(self): def test_noise_log_likelihood(self): """Test noise log likelihood matches precomputed value""" - self.likelihood.noise_log_likelihood() + nll = self.likelihood.noise_log_likelihood() self.assertAlmostEqual( - -4014.1787704539474, self.likelihood.noise_log_likelihood(), 3 + -4014.1787704539474, nll, 3 ) + self.assertEqual(nll.__array_namespace__(), self.xp) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" - self.likelihood.log_likelihood() - self.assertAlmostEqual(self.likelihood.log_likelihood(), - -4032.4397343470005, 3) + logl = self.likelihood.log_likelihood(self.parameters) + self.assertAlmostEqual(logl, -4032.4397343470005, 3) + self.assertEqual(logl.__array_namespace__(), self.xp) def test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" + llr = self.likelihood.log_likelihood_ratio(self.parameters) self.assertAlmostEqual( - self.likelihood.log_likelihood() - self.likelihood.noise_log_likelihood(), - self.likelihood.log_likelihood_ratio(), + self.likelihood.log_likelihood(self.parameters) - self.likelihood.noise_log_likelihood(), + llr, 3, ) + self.assertEqual(llr.__array_namespace__(), self.xp) def test_likelihood_zero_when_waveform_is_none(self): """Test log likelihood returns np.nan_to_num(-np.inf) when the waveform is None""" self.likelihood.waveform_generator.frequency_domain_strain = lambda x: None - self.assertEqual(self.likelihood.log_likelihood_ratio(), 
np.nan_to_num(-np.inf)) + self.assertEqual(self.likelihood.log_likelihood_ratio(self.parameters), np.nan_to_num(-np.inf)) def test_repr(self): expected = ( @@ -239,14 +282,16 @@ def test_reference_frame_agrees_with_default(self): ) parameters = self.parameters.copy() del parameters["ra"], parameters["dec"] - parameters["zenith"] = 1.0 - parameters["azimuth"] = 1.0 + parameters["zenith"] = self.xp.array(1.0) + parameters["azimuth"] = self.xp.array(1.0) parameters["ra"], parameters["dec"] = bilby.gw.utils.zenith_azimuth_to_ra_dec( zenith=parameters["zenith"], azimuth=parameters["azimuth"], geocent_time=parameters["geocent_time"], - ifos=bilby.gw.detector.InterferometerList(["H1", "L1"]) + ifos=new_likelihood.reference_frame, ) + self.assertEqual(parameters["ra"].__array_namespace__(), self.xp) + self.assertEqual(parameters["dec"].__array_namespace__(), self.xp) self.assertEqual( new_likelihood.log_likelihood_ratio(parameters), self.likelihood.log_likelihood_ratio(parameters) @@ -286,10 +331,12 @@ def test_time_reference_agrees_with_default(self): @pytest.mark.requires_roqs +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestROQLikelihood(unittest.TestCase): def setUp(self): - self.duration = 4 - self.sampling_frequency = 2048 + self.duration = self.xp.array(4.0) + self.sampling_frequency = self.xp.array(2048.0) self.test_parameters = dict( mass_1=36.0, @@ -313,6 +360,7 @@ def setUp(self): ifos.set_strain_data_from_power_spectral_densities( sampling_frequency=self.sampling_frequency, duration=self.duration ) + ifos.set_array_backend(self.xp) self.priors = bilby.gw.prior.BBHPriorDict() self.priors.pop("mass_1") @@ -332,6 +380,7 @@ def setUp(self): waveform_approximant="IMRPhenomPv2", ), ) + non_roq_wfg = BackendWaveformGenerator(non_roq_wfg, self.xp) ifos.inject_signal( parameters=self.test_parameters, waveform_generator=non_roq_wfg @@ -392,7 +441,7 @@ def roq_wfg(self): fnodes_quadratic_file = f"{self.roq_dir}/fnodes_quadratic.npy" 
fnodes_linear = np.load(fnodes_linear_file).T fnodes_quadratic = np.load(fnodes_quadratic_file).T - return bilby.gw.waveform_generator.WaveformGenerator( + wfg = bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, @@ -403,6 +452,7 @@ def roq_wfg(self): waveform_approximant="IMRPhenomPv2", ), ) + return BackendWaveformGenerator(wfg, self.xp) @cached_property def roq(self): @@ -426,13 +476,14 @@ def roq_phase(self): ) def test_matches_non_roq(self): + roq_llr = self.roq.log_likelihood_ratio(self.test_parameters) self.assertLess( abs( - self.non_roq.log_likelihood_ratio(self.test_parameters) - - self.roq.log_likelihood_ratio(self.test_parameters) + self.non_roq.log_likelihood_ratio(self.test_parameters) - roq_llr ) / self.non_roq.log_likelihood_ratio(self.test_parameters), 1e-3, ) + self.assertEqual(roq_llr.__array_namespace__(), self.xp) self.non_roq.parameters.update(self.test_parameters) self.roq.parameters.update(self.test_parameters) self.assertLess( @@ -457,10 +508,12 @@ def test_create_roq_weights_with_params(self): quadratic_matrix=self.quadratic_matrix_file, priors=self.priors, ) + roq_llr = roq.log_likelihood_ratio(self.test_parameters) self.assertEqual( - roq.log_likelihood_ratio(self.test_parameters), + roq_llr, self.roq.log_likelihood_ratio(self.test_parameters) ) + self.assertEqual(roq_llr.__array_namespace__(), self.xp) roq.parameters.update(self.test_parameters) self.roq.parameters.update(self.test_parameters) self.assertEqual(roq.log_likelihood_ratio(), self.roq.log_likelihood_ratio()) @@ -647,6 +700,8 @@ def test_rescaling(self): @pytest.mark.requires_roqs +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestROQLikelihoodHDF5(unittest.TestCase): """ Test ROQ likelihood constructed from .hdf5 basis @@ -660,9 +715,9 @@ class TestROQLikelihoodHDF5(unittest.TestCase): _path_to_basis_mb = 
"/roq_basis/basis_multiband_addcal.hdf5" def setUp(self): - self.minimum_frequency = 20 - self.sampling_frequency = 2048 - self.duration = 16 + self.minimum_frequency = self.xp.array(20.0) + self.sampling_frequency = self.xp.array(2048.0) + self.duration = self.xp.array(16.0) self.reference_frequency = 20.0 self.waveform_approximant = "IMRPhenomD" # The SNRs of injections are 130-160 for roq_scale_factor=1 and 70-80 for roq_scale_factor=2 @@ -704,10 +759,11 @@ def test_fails_with_frequency_duration_mismatch( self.priors["chirp_mass"].maximum = 9 interferometers = bilby.gw.detector.InterferometerList(["H1"]) interferometers.set_strain_data_from_power_spectral_densities( - sampling_frequency=2 * maximum_frequency, - duration=duration, - start_time=self.injection_parameters["geocent_time"] - duration + 1 + sampling_frequency=self.xp.array(2 * maximum_frequency), + duration=self.xp.array(duration), + start_time=self.xp.array(self.injection_parameters["geocent_time"] - duration + 1) ) + interferometers.set_array_backend(self.xp) for ifo in interferometers: ifo.minimum_frequency = minimum_frequency ifo.maximum_frequency = maximum_frequency @@ -737,10 +793,11 @@ def test_fails_with_prior_mismatch(self, basis, chirp_mass_min, chirp_mass_max): self.priors["chirp_mass"].maximum = chirp_mass_max interferometers = bilby.gw.detector.InterferometerList(["H1"]) interferometers.set_strain_data_from_power_spectral_densities( - sampling_frequency=self.sampling_frequency, - duration=self.duration, - start_time=self.injection_parameters["geocent_time"] - self.duration + 1 + sampling_frequency=self.xp.array(self.sampling_frequency), + duration=self.xp.array(self.duration), + start_time=self.xp.array(self.injection_parameters["geocent_time"] - self.duration + 1) ) + interferometers.set_array_backend(self.xp) for ifo in interferometers: ifo.minimum_frequency = self.minimum_frequency search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( @@ -879,6 +936,7 @@ def 
assertLess_likelihood_errors( self.priors["chirp_mass"].maximum = mc_max interferometers = bilby.gw.detector.InterferometerList(["H1", "L1"]) + interferometers.set_array_backend(self.xp) for ifo in interferometers: if minimum_frequency is None: ifo.minimum_frequency = self.minimum_frequency @@ -920,6 +978,7 @@ def assertLess_likelihood_errors( waveform_approximant=self.waveform_approximant ) ) + waveform_generator = BackendWaveformGenerator(waveform_generator, self.xp) interferometers.inject_signal(waveform_generator=waveform_generator, parameters=self.injection_parameters) likelihood = bilby.gw.GravitationalWaveTransient( @@ -937,6 +996,7 @@ def assertLess_likelihood_errors( waveform_approximant=self.waveform_approximant ) ) + search_waveform_generator = BackendWaveformGenerator(search_waveform_generator, self.xp) likelihood_roq = bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=interferometers, priors=self.priors, @@ -951,6 +1011,7 @@ def assertLess_likelihood_errors( llr = likelihood.log_likelihood_ratio(parameters) llr_roq = likelihood_roq.log_likelihood_ratio(parameters) self.assertLess(np.abs(llr - llr_roq), max_llr_error) + self.assertEqual(llr_roq.__array_namespace__(), self.xp) likelihood.parameters.update(parameters) likelihood_roq.parameters.update(parameters) llr = likelihood.log_likelihood_ratio() @@ -1283,11 +1344,13 @@ def test_instantiation(self): self.like = bilby.gw.likelihood.get_binary_black_hole_likelihood(self.ifos) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestMBLikelihood(unittest.TestCase): def setUp(self): - self.duration = 16 - self.fmin = 20. - self.sampling_frequency = 2048. 
+ self.duration = self.xp.array(16.0) + self.fmin = self.xp.array(20.0) + self.sampling_frequency = self.xp.array(2048.0) self.test_parameters = dict( chirp_mass=6.0, mass_ratio=0.5, @@ -1315,6 +1378,7 @@ def setUp(self): ) for ifo in self.ifos: ifo.minimum_frequency = self.fmin + self.ifos.set_array_backend(self.xp) spline_calibration_nodes = 10 self.calibration_parameters = {} @@ -1370,6 +1434,7 @@ def test_matches_original_likelihood( reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + wfg = BackendWaveformGenerator(wfg, self.xp) self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg) wfg_mb = bilby.gw.WaveformGenerator( @@ -1379,6 +1444,7 @@ def test_matches_original_likelihood( reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + wfg_mb = BackendWaveformGenerator(wfg_mb, self.xp) likelihood = bilby.gw.likelihood.GravitationalWaveTransient( interferometers=self.ifos, waveform_generator=wfg ) @@ -1391,10 +1457,12 @@ def test_matches_original_likelihood( parameters = deepcopy(self.test_parameters) if add_cal_errors: parameters.update(self.calibration_parameters) + llmb = likelihood_mb.log_likelihood_ratio(parameters) self.assertLess( - abs(likelihood.log_likelihood_ratio(parameters) - likelihood_mb.log_likelihood_ratio(parameters)), + abs(likelihood.log_likelihood_ratio(parameters) - llmb), tolerance ) + self.assertEqual(llmb.__array_namespace__(), self.xp) likelihood.parameters.update(parameters) likelihood_mb.parameters.update(parameters) self.assertLess( @@ -1414,6 +1482,7 @@ def test_large_accuracy_factor(self): reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + wfg = BackendWaveformGenerator(wfg, self.xp) self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg) wfg_mb = bilby.gw.WaveformGenerator( @@ -1423,6 +1492,7 @@ def test_large_accuracy_factor(self): reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) 
) + wfg_mb = BackendWaveformGenerator(wfg_mb, self.xp) likelihood = bilby.gw.likelihood.GravitationalWaveTransient( interferometers=self.ifos, waveform_generator=wfg ) @@ -1559,12 +1629,14 @@ def test_inout_weights(self, linear_interpolation): reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + wfg_mb = BackendWaveformGenerator(wfg_mb, self.xp) likelihood_mb_from_weights = bilby.gw.likelihood.MBGravitationalWaveTransient( interferometers=self.ifos, waveform_generator=wfg_mb, weights=filepath ) # likelihood_mb_from_weights.parameters.update(self.test_parameters) llr_from_weights = likelihood_mb_from_weights.log_likelihood_ratio(self.test_parameters) + self.assertEqual(llr_from_weights.__array_namespace__(), self.xp) self.assertAlmostEqual(llr, llr_from_weights) @@ -1581,6 +1653,7 @@ def test_from_dict_weights(self, linear_interpolation): reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + wfg = BackendWaveformGenerator(wfg, self.xp) self.ifos.inject_signal( parameters=self.test_parameters, waveform_generator=wfg ) @@ -1592,6 +1665,7 @@ def test_from_dict_weights(self, linear_interpolation): reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + wfg_mb = BackendWaveformGenerator(wfg_mb, self.xp) likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient( interferometers=self.ifos, waveform_generator=wfg_mb, reference_chirp_mass=self.test_parameters['chirp_mass'], @@ -1599,6 +1673,7 @@ def test_from_dict_weights(self, linear_interpolation): ) likelihood_mb.parameters.update(self.test_parameters) llr = likelihood_mb.log_likelihood_ratio() + self.assertEqual(llr.__array_namespace__(), self.xp) # reset waveform generator to check if likelihood recovered from the weights properly adds banded # frequency points to waveform arguments @@ -1609,12 +1684,14 @@ def test_from_dict_weights(self, linear_interpolation): reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + 
wfg_mb = BackendWaveformGenerator(wfg_mb, self.xp) weights = likelihood_mb.weights likelihood_mb_from_weights = bilby.gw.likelihood.MBGravitationalWaveTransient( interferometers=self.ifos, waveform_generator=wfg_mb, weights=weights ) # likelihood_mb_from_weights.parameters.update(self.test_parameters) llr_from_weights = likelihood_mb_from_weights.log_likelihood_ratio(self.test_parameters) + self.assertEqual(llr_from_weights.__array_namespace__(), self.xp) self.assertAlmostEqual(llr, llr_from_weights) @@ -1639,6 +1716,7 @@ def test_matches_original_likelihood_low_maximum_frequency( reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + wfg = BackendWaveformGenerator(wfg, self.xp) self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg) wfg_mb = bilby.gw.WaveformGenerator( @@ -1648,6 +1726,7 @@ def test_matches_original_likelihood_low_maximum_frequency( reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + wfg_mb = BackendWaveformGenerator(wfg_mb, self.xp) likelihood = bilby.gw.likelihood.GravitationalWaveTransient( interferometers=self.ifos, waveform_generator=wfg ) @@ -1660,10 +1739,12 @@ def test_matches_original_likelihood_low_maximum_frequency( parameters = deepcopy(self.test_parameters) if add_cal_errors: parameters.update(self.calibration_parameters) + llrmb = likelihood_mb.log_likelihood_ratio(parameters) self.assertLess( - abs(likelihood.log_likelihood_ratio(parameters) - likelihood_mb.log_likelihood_ratio(parameters)), + abs(likelihood.log_likelihood_ratio(parameters) - llrmb), tolerance ) + self.assertEqual(llrmb.__array_namespace__(), self.xp) likelihood.parameters.update(parameters) likelihood_mb.parameters.update(parameters) self.assertLess( diff --git a/test/gw/prior_test.py b/test/gw/prior_test.py index 022bc0a56..c4afa00f4 100644 --- a/test/gw/prior_test.py +++ b/test/gw/prior_test.py @@ -222,6 +222,8 @@ def test_pickle_prior(self): self.assertEqual(priors, priors_loaded) 
+@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestPriorConversion(unittest.TestCase): def test_bilby_to_lalinference(self): mass_1 = [1, 20] @@ -256,10 +258,11 @@ def test_bilby_to_lalinference(self): ) nsamples = 5000 - bilby_samples = bilby_prior.sample(nsamples) + bilby_samples = bilby_prior.sample(nsamples, xp=self.xp) bilby_samples, _ = conversion.convert_to_lal_binary_black_hole_parameters( bilby_samples ) + bilby_samples = pd.DataFrame(bilby_samples) # Quicker way to generate LA prior samples (rather than specifying Constraint) lalinf_samples = [] @@ -279,7 +282,7 @@ def test_bilby_to_lalinference(self): result.search_parameter_keys = ["mass_ratio", "chirp_mass"] result.meta_data = dict() result.priors = bilby_prior - result.posterior = pd.DataFrame(bilby_samples) + result.posterior = bilby_samples result_converted = bilby.gw.prior.convert_to_flat_in_component_mass_prior( result, fraction=0.1 ) diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index cf78849c7..d082a01a1 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -15,6 +15,8 @@ from bilby.gw import utils as gwutils +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestGWUtils(unittest.TestCase): def setUp(self): self.outdir = "outdir" @@ -27,29 +29,32 @@ def tearDown(self): pass def test_asd_from_freq_series(self): - freq_data = np.array([1, 2, 3]) + freq_data = self.xp.array([1, 2, 3]) df = 0.1 asd = gwutils.asd_from_freq_series(freq_data, df) self.assertTrue(np.all(asd == freq_data * 2 * df ** 0.5)) + self.assertEqual(asd.__array_namespace__(), self.xp) def test_psd_from_freq_series(self): - freq_data = np.array([1, 2, 3]) + freq_data = self.xp.array([1, 2, 3]) df = 0.1 psd = gwutils.psd_from_freq_series(freq_data, df) self.assertTrue(np.all(psd == (freq_data * 2 * df ** 0.5) ** 2)) + self.assertEqual(psd.__array_namespace__(), self.xp) def test_inner_product(self): - aa = np.array([1, 2, 3]) - bb = np.array([5, 6, 7]) - 
frequency = np.array([0.2, 0.4, 0.6]) + aa = self.xp.array([1, 2, 3]) + bb = self.xp.array([5, 6, 7]) + frequency = self.xp.array([0.2, 0.4, 0.6]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() ip = gwutils.inner_product(aa, bb, frequency, PSD) self.assertEqual(ip, 0) + self.assertEqual(ip.__array_namespace__(), self.xp) def test_noise_weighted_inner_product(self): - aa = np.array([1e-23, 2e-23, 3e-23]) - bb = np.array([5e-23, 6e-23, 7e-23]) - frequency = np.array([100, 101, 102]) + aa = self.xp.array([1e-23, 2e-23, 3e-23]) + bb = self.xp.array([5e-23, 6e-23, 7e-23]) + frequency = self.xp.array([100, 101, 102]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() psd = PSD.power_spectral_density_interpolated(frequency) duration = 4 @@ -60,11 +65,12 @@ def test_noise_weighted_inner_product(self): gwutils.optimal_snr_squared(aa, psd, duration), gwutils.noise_weighted_inner_product(aa, aa, psd, duration), ) + self.assertEqual(nwip.__array_namespace__(), self.xp) def test_matched_filter_snr(self): - signal = np.array([1e-23, 2e-23, 3e-23]) - frequency_domain_strain = np.array([5e-23, 6e-23, 7e-23]) - frequency = np.array([100, 101, 102]) + signal = self.xp.array([1e-23, 2e-23, 3e-23]) + frequency_domain_strain = self.xp.array([5e-23, 6e-23, 7e-23]) + frequency = self.xp.array([100, 101, 102]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() psd = PSD.power_spectral_density_interpolated(frequency) duration = 4 @@ -73,6 +79,27 @@ def test_matched_filter_snr(self): signal, frequency_domain_strain, psd, duration ) self.assertEqual(mfsnr, 25.510869054168282) + self.assertEqual(mfsnr.__array_namespace__(), self.xp) + + def test_overlap(self): + signal = self.xp.linspace(1e-23, 21e-23, 21) + frequency_domain_strain = self.xp.linspace(5e-23, 25e-23, 21) + frequency = self.xp.linspace(100, 120, 21) + PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() + psd = PSD.power_spectral_density_interpolated(frequency) + duration = 4 + overlap = 
gwutils.overlap( + signal, + frequency_domain_strain, + psd, + delta_frequency=1 / duration, + lower_cut_off=3, + upper_cut_off=18, + norm_a=gwutils.optimal_snr_squared(signal, psd, duration), + norm_b=gwutils.optimal_snr_squared(frequency_domain_strain, psd, duration), + ) + self.assertAlmostEqual(overlap, 2.76914407e-05) + self.assertEqual(overlap.__array_namespace__(), self.xp) @pytest.mark.skip(reason="GWOSC unstable: avoiding this test") def test_get_event_time(self): @@ -264,6 +291,8 @@ def test_safe_cast_mode_to_int(self): gwutils.safe_cast_mode_to_int(None) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestSkyFrameConversion(unittest.TestCase): def setUp(self) -> None: @@ -281,23 +310,37 @@ def tearDown(self) -> None: del self.ifos del self.samples + def test_conversion_single(self) -> None: + sample = self.priors.sample() + zenith = self.xp.asarray(sample["zenith"]) + azimuth = self.xp.asarray(sample["azimuth"]) + time = self.xp.asarray(sample["time"]) + self.ifos.set_array_backend(self.xp) + ra, dec = bilby.gw.utils.zenith_azimuth_to_ra_dec( + zenith, azimuth, time, self.ifos + ) + self.assertEqual(ra.__array_namespace__(), self.xp) + self.assertEqual(dec.__array_namespace__(), self.xp) + def test_conversion_gives_correct_prior(self) -> None: - zeniths = self.samples["zenith"] - azimuths = self.samples["azimuth"] - times = self.samples["time"] - args = zip(*[ - (zenith, azimuth, time, self.ifos) - for zenith, azimuth, time in zip(zeniths, azimuths, times) - ]) - ras, decs = zip(*map(bilby.gw.utils.zenith_azimuth_to_ra_dec, *args)) + zeniths = self.xp.asarray(self.samples["zenith"]) + azimuths = self.xp.asarray(self.samples["azimuth"]) + times = self.xp.asarray(self.samples["time"]) + self.ifos.set_array_backend(self.xp) + ras, decs = bilby.gw.utils.zenith_azimuth_to_ra_dec( + zeniths, azimuths, times, self.ifos + ) self.assertGreaterEqual(ks_2samp(self.samples["ra"], ras).pvalue, 0.01) 
self.assertGreaterEqual(ks_2samp(self.samples["dec"], decs).pvalue, 0.01) + self.assertEqual(ras.__array_namespace__(), self.xp) + self.assertEqual(decs.__array_namespace__(), self.xp) -def test_ln_i0_mathces_scipy(): +@pytest.mark.array_backend +def test_ln_i0_mathces_scipy(xp): from scipy.special import i0 - values = np.linspace(-10, 10, 101) - assert max(abs(gwutils.ln_i0(values) - np.log(i0(values)))) < 1e-10 + values = xp.linspace(-10, 10, 101) + assert max(abs(gwutils.ln_i0(values) - xp.log(i0(values)))) < 1e-10 if __name__ == "__main__": diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py index f63b40537..efd59d352 100644 --- a/test/gw/waveform_generator_test.py +++ b/test/gw/waveform_generator_test.py @@ -4,6 +4,8 @@ import bilby import lalsimulation import numpy as np +import pytest +from bilby.compat.utils import xp_wrap def dummy_func_array_return_value( @@ -36,16 +38,21 @@ def dummy_func_dict_return_value( return ht +@xp_wrap def dummy_func_array_return_value_2( - array, amplitude, mu, sigma, ra, dec, geocent_time, psi + array, amplitude, mu, sigma, ra, dec, geocent_time, psi, *, xp=None ): - return dict(plus=np.array(array), cross=np.array(array)) + return dict(plus=xp.array(array), cross=xp.array(array)) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - 1, 4096, frequency_domain_source_model=dummy_func_dict_return_value + self.xp.array(1.0), + self.xp.array(4096.0), + frequency_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( amplitude=1e-21, @@ -118,9 +125,11 @@ def conversion_func(): def test_duration(self): self.assertEqual(self.waveform_generator.duration, 1) + self.assertEqual(self.waveform_generator.duration.__array_namespace__(), self.xp) def test_sampling_frequency(self): 
self.assertEqual(self.waveform_generator.sampling_frequency, 4096) + self.assertEqual(self.waveform_generator.sampling_frequency.__array_namespace__(), self.xp) def test_source_model(self): self.assertEqual( @@ -129,10 +138,10 @@ def test_source_model(self): ) def test_frequency_array_type(self): - self.assertIsInstance(self.waveform_generator.frequency_array, np.ndarray) + self.assertIsInstance(self.waveform_generator.frequency_array, self.xp.ndarray) def test_time_array_type(self): - self.assertIsInstance(self.waveform_generator.time_array, np.ndarray) + self.assertIsInstance(self.waveform_generator.time_array, self.xp.ndarray) def test_source_model_parameters(self): self.waveform_generator.parameters = self.simulation_parameters.copy() @@ -301,11 +310,13 @@ def conversion_func(): self.assertEqual(conversion_func, self.waveform_generator.parameter_conversion) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestFrequencyDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=1, - sampling_frequency=4096, + duration=self.xp.array(1.0), + sampling_frequency=self.xp.array(4096.0), frequency_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( @@ -347,6 +358,8 @@ def test_frequency_domain_source_model_call(self): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(actual["plus"].__array_namespace__(), self.xp) + self.assertEqual(actual["cross"].__array_namespace__(), self.xp) def test_time_domain_source_model_call_with_ndarray(self): self.waveform_generator.frequency_domain_source_model = None @@ -364,6 +377,7 @@ def side_effect(value, value2): parameters=self.simulation_parameters ) self.assertTrue(np.array_equal(expected, actual)) + self.assertEqual(actual.__array_namespace__(), self.xp) def 
test_time_domain_source_model_call_with_dict(self): self.waveform_generator.frequency_domain_source_model = None @@ -382,6 +396,8 @@ def side_effect(value, value2): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(actual["plus"].__array_namespace__(), self.xp) + self.assertEqual(actual["cross"].__array_namespace__(), self.xp) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None @@ -491,8 +507,8 @@ def test_frequency_domain_caching_changing_model(self): def test_time_domain_caching_changing_model(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=1, - sampling_frequency=4096, + duration=self.xp.array(1.0), + sampling_frequency=self.xp.array(4096.0), time_domain_source_model=dummy_func_dict_return_value, ) original_waveform = self.waveform_generator.frequency_domain_strain( @@ -507,12 +523,18 @@ def test_time_domain_caching_changing_model(self): self.assertFalse( np.array_equal(original_waveform["plus"], new_waveform["plus"]) ) + self.assertEqual(new_waveform["plus"].__array_namespace__(), self.xp) + self.assertEqual(new_waveform["cross"].__array_namespace__(), self.xp) +@pytest.mark.array_backend +@pytest.mark.usefixtures("xp_class") class TestTimeDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - 1, 4096, time_domain_source_model=dummy_func_dict_return_value + self.xp.array(1.0), + self.xp.array(4096.0), + time_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( amplitude=1e-21, @@ -553,6 +575,8 @@ def test_time_domain_source_model_call(self): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(actual["plus"].__array_namespace__(), self.xp) + 
self.assertEqual(actual["cross"].__array_namespace__(), self.xp) def test_frequency_domain_source_model_call_with_ndarray(self): self.waveform_generator.time_domain_source_model = None @@ -572,6 +596,7 @@ def side_effect(value, value2): parameters=self.simulation_parameters ) self.assertTrue(np.array_equal(expected, actual)) + self.assertEqual(actual.__array_namespace__(), self.xp) def test_frequency_domain_source_model_call_with_dict(self): self.waveform_generator.time_domain_source_model = None @@ -592,6 +617,8 @@ def side_effect(value, value2): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) + self.assertEqual(actual["plus"].__array_namespace__(), self.xp) + self.assertEqual(actual["cross"].__array_namespace__(), self.xp) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None From aafbf1b98aad68422de8e44893ce5a4a3b447f8f Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 10:58:17 -0500 Subject: [PATCH 083/110] DOC: fix doc page formatting --- docs/array_api.rst | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/docs/array_api.rst b/docs/array_api.rst index 57c187d72..f3d293b02 100644 --- a/docs/array_api.rst +++ b/docs/array_api.rst @@ -40,7 +40,7 @@ Basic Prior Usage (Automatic Detection) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The array backend is automatically detected from your input arrays. You typically don't need -to specify the ``xp`` parameter:: +to specify the ``xp`` parameter: .. code-block:: python @@ -62,7 +62,7 @@ Sampling with Array Backends (Explicit xp Required) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ When sampling from priors, you **must** explicitly specify the array backend using the ``xp`` parameter, -as there's no input array to infer the backend from:: +as there's no input array to infer the backend from: .. 
code-block:: python @@ -84,7 +84,7 @@ as there's no input array to infer the backend from:: Prior Dictionaries ~~~~~~~~~~~~~~~~~~ -Prior dictionaries work the same way - automatic detection for most methods, explicit ``xp`` for sampling:: +Prior dictionaries work the same way - automatic detection for most methods, explicit ``xp`` for sampling: .. code-block:: python @@ -237,12 +237,14 @@ Performance Considerations 3. Avoid mixing array types in the same computation 4. For JAX, consider using ``jax.jit`` for repeated computations 5. Profile your code to ensure the chosen backend provides benefits +6. If you find :code:`xp_wrap` is a bottleneck in your code, you can explicitly pass + :code:`xp` to the function/method to skip the automatic backend detection step. Bilby and JIT compilation ~~~~~~~~~~~~~~~~~~~~~~~~~ Currently, Bilby functions are not JIT-compiled by default. -Additionally, many Bilby types are not defined as :code:`JAX`` :code:`PyTrees`, +Additionally, many Bilby types are not defined as :code:`JAX` :code:`PyTrees`, and so cannot be passed as arguments to JIT-compiled functions. We plan to support JIT-compilation for at least some Bilby types in future releases. @@ -254,9 +256,9 @@ When creating custom priors, ensure they support the Array API: Example Implementation ~~~~~~~~~~~~~~~~~~~~~~ -Always include the ``xp`` parameter with a default value:: +Always include the ``xp`` parameter with a default value: -... code-block:: python +.. code-block:: python from bilby.core.prior import Prior @@ -286,7 +288,7 @@ it will be automatically inferred from their input arrays. They only need to spe Using the :code:`xp_wrap`` Decorator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -For methods that perform array operations, use the ``@xp_wrap`` decorator:: +For methods that perform array operations, use the ``@xp_wrap`` decorator: .. 
code-block:: python @@ -334,7 +336,7 @@ Method Signature Pattern All array-processing methods in prior classes follow this pattern: -**For methods with @xp_wrap decorator**:: +**For methods with @xp_wrap decorator**: .. code-block:: python @@ -343,7 +345,7 @@ All array-processing methods in prior classes follow this pattern: """Method that uses xp for array operations.""" return xp.some_operation(val) * self.is_in_prior_range(val) -**For methods without @xp_wrap (that use xp directly)**:: +**For methods without @xp_wrap (that use xp directly)**: .. code-block:: python @@ -367,9 +369,9 @@ Located in ``bilby/compat/utils.py``, this decorator: 2. **Provides the appropriate xp** when ``xp=None`` 3. **Maintains backward compatibility** with code that doesn't pass ``xp`` -Example implementation pattern:: +Example implementation pattern: -... code-block:: python +.. code-block:: python from bilby.compat.utils import xp_wrap @@ -388,7 +390,9 @@ Test Structure ~~~~~~~~~~~~~~ When appropriate, tests should verify functionality across different -backends using the ``array_backend`` marker:: +backends using the ``array_backend`` marker: + +.. code-block:: python @pytest.mark.array_backend @pytest.mark.usefixtures("xp_class") @@ -473,7 +477,7 @@ The ``array_api_extra.at`` function provides a unified interface for array updat Usage Examples ~~~~~~~~~~~~~~ -**Conditional update**:: +**Conditional update**: .. code-block:: python @@ -486,7 +490,7 @@ Usage Examples arr = xpx.at(arr)[mask].set(value) return arr -**Increment operation**:: +**Increment operation**: .. code-block:: python @@ -513,7 +517,7 @@ Important Notes 1. **Return value**: Always use the returned array. The operation may create a new array (JAX) or modify in-place (NumPy). -2. **Import**: Import ``array_api_extra`` at the module level:: +2. **Import**: Import ``array_api_extra`` at the module level: .. 
code-block:: python From 8ca99788463f461cf48304d6ac29b7c7e833beb2 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 11:13:07 -0500 Subject: [PATCH 084/110] FMT: fix formatting --- test/gw/conversion_test.py | 1 + test/gw/likelihood_test.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index 621a766ae..4961c11e0 100644 --- a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -640,6 +640,7 @@ def setUp(self): 'total_mass_source': self.xp.array(6), 'symmetric_mass_ratio': self.xp.array(0.2222222222222222), 'mass_ratio': self.xp.array(0.5)} + def helper_generation_from_keys(self, keys, expected_values, source=False): # Explicitly test the helper generate_component_masses local_test_vars = \ diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 087e00848..6acf60757 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -24,7 +24,7 @@ def __getattr__(self, name): if name == "xp": return self.xp return getattr(self.wfg, name) - + def convert_nested_dict(self, data): if is_array_api_obj(data): return self.xp.array(data) @@ -36,7 +36,7 @@ def convert_nested_dict(self, data): def frequency_domain_strain(self, parameters): wf = self.wfg.frequency_domain_strain(parameters) return self.convert_nested_dict(wf) - + def time_domain_strain(self, parameters): wf = self.wfg.time_domain_strain(parameters) return self.convert_nested_dict(wf) From bdda315bdf5a045a4f2cacd813653589d47f2e03 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 11:29:57 -0500 Subject: [PATCH 085/110] BUG: fix typo in bilby_cython call --- bilby/gw/compat/cython.py | 2 +- test/conftest.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bilby/gw/compat/cython.py b/bilby/gw/compat/cython.py index f42301875..9d0a69af0 100644 --- a/bilby/gw/compat/cython.py +++ b/bilby/gw/compat/cython.py @@ -47,7 +47,7 @@ def 
get_polarization_tensor(ra: Real, dec: Real, time: Real, psi: Real, mode: st @dispatch(precedence=1) def rotation_matrix_from_delta(delta: ArrayLike): - return _geometry.rotation_matrix_from_delta_x(delta) + return _geometry.rotation_matrix_from_delta(delta) @dispatch(precedence=1) diff --git a/test/conftest.py b/test/conftest.py index b1668561a..f7f5a17a3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -22,15 +22,15 @@ def pytest_collection_modifyitems(config, items): if config.getoption("--skip-roqs"): skip_roqs = pytest.mark.skip(reason="Skipping tests that require ROQs") else: - skip_roqs = pytest.mark.noop + skip_roqs = None if config.getoption("--array-backend") is not None: array_only = pytest.mark.skip(reason="Only running backend dependent tests") else: - array_only = pytest.mark.noop + array_only = None for item in items: if "requires_roqs" in item.keywords and config.getoption("--skip-roqs"): item.add_marker(skip_roqs) - elif "array_backend" not in item.keywords: + elif "array_backend" not in item.keywords and array_only is not None: item.add_marker(array_only) From 5bf699e660b7bf066f66c6c8de990418fbe5eb48 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 11:32:49 -0500 Subject: [PATCH 086/110] BUG: fix list input for asd calculation --- bilby/gw/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index 7799c23ca..013ecc581 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -15,7 +15,8 @@ from ..core.utils.constants import solar_mass -def asd_from_freq_series(freq_data, df): +@xp_wrap +def asd_from_freq_series(freq_data, df, *, xp=None): """ Calculate the ASD from the frequency domain output of gaussian_noise() @@ -31,10 +32,11 @@ def asd_from_freq_series(freq_data, df): array_like: array of real-valued normalized frequency domain ASD data """ - return abs(freq_data) * 2 * df**0.5 + return xp.abs(freq_data) * 2 * df**0.5 -def 
psd_from_freq_series(freq_data, df): +@xp_wrap +def psd_from_freq_series(freq_data, df, *, xp=None): """ Calculate the PSD from the frequency domain output of gaussian_noise() Calls asd_from_freq_series() and squares the output @@ -51,7 +53,7 @@ def psd_from_freq_series(freq_data, df): array_like: Real-valued normalized frequency domain PSD data """ - return asd_from_freq_series(freq_data, df) ** 2 + return asd_from_freq_series(freq_data, df, xp=xp) ** 2 def get_vertex_position_geocentric(latitude, longitude, elevation): From acd22a37f809217c0c9ead38acb7edc97aa3f645 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 12:02:16 -0500 Subject: [PATCH 087/110] FMT: fix syntax for array conversion and backend checks --- bilby/compat/patches.py | 4 +- bilby/core/grid.py | 15 ++-- bilby/core/likelihood.py | 13 ++-- bilby/core/prior/analytical.py | 22 +++--- bilby/core/prior/base.py | 2 +- bilby/core/prior/conditional.py | 2 +- bilby/core/prior/dict.py | 6 +- bilby/core/prior/joint.py | 12 ++-- bilby/core/utils/calculus.py | 2 +- bilby/core/utils/io.py | 3 +- bilby/gw/conversion.py | 8 +-- bilby/gw/detector/calibration.py | 9 +-- bilby/gw/detector/geometry.py | 16 ++--- bilby/gw/geometry.py | 22 +++--- bilby/gw/likelihood/roq.py | 3 +- bilby/gw/prior.py | 6 +- bilby/gw/source.py | 2 +- bilby/gw/time.py | 2 +- bilby/gw/utils.py | 2 +- bilby/hyper/likelihood.py | 2 +- docs/array_api.rst | 8 +-- test/conftest.py | 10 +-- test/core/grid_test.py | 6 +- test/core/likelihood_test.py | 51 +++++++------- test/core/prior/analytical_test.py | 105 +++++++++++++++------------- test/core/prior/base_test.py | 3 +- test/core/prior/conditional_test.py | 15 ++-- test/core/prior/dict_test.py | 25 +++---- test/core/prior/prior_test.py | 49 ++++++------- test/core/prior/slabspike_test.py | 22 +++--- test/core/result_test.py | 2 +- test/core/series_test.py | 24 +++---- test/core/utils_test.py | 67 +++++++++--------- test/gw/conversion_test.py | 93 ++++++++++++------------ 
test/gw/detector/geometry_test.py | 33 ++++----- test/gw/likelihood_test.py | 75 ++++++++++---------- test/gw/prior_test.py | 5 +- test/gw/utils_test.py | 43 ++++++------ test/gw/waveform_generator_test.py | 47 +++++++------ 39 files changed, 434 insertions(+), 402 deletions(-) diff --git a/bilby/compat/patches.py b/bilby/compat/patches.py index db18c3974..19ad0565a 100644 --- a/bilby/compat/patches.py +++ b/bilby/compat/patches.py @@ -30,7 +30,7 @@ def multivariate_logpdf(xp, mean, cov): elif aac.is_torch_namespace(xp): from torch.distributions.multivariate_normal import MultivariateNormal - mvn = MultivariateNormal(loc=mean, covariance_matrix=xp.array(cov)) + mvn = MultivariateNormal(loc=mean, covariance_matrix=xp.asarray(cov)) logpdf = mvn.log_prob else: raise BackendNotImplementedError @@ -39,7 +39,7 @@ def multivariate_logpdf(xp, mean, cov): def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, *, xp=None): if xp is None: - xp = a.__array_namespace__() + xp = aac.get_namespace(a) if "jax" in xp.__name__: # the scipy version of logsumexp cannot be vmapped diff --git a/bilby/core/grid.py b/bilby/core/grid.py index 55ff2fb2d..2377dc0d5 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -1,6 +1,7 @@ import json import os +import array_api_compat as aac import numpy as np from .likelihood import _safe_likelihood_call @@ -88,7 +89,7 @@ def __init__( enumerate(self.parameter_names)}, axis=0).reshape( self.mesh_grid[0].shape) else: - self._ln_prior = xp.array(0.0) + self._ln_prior = xp.asarray(0.0) self._ln_likelihood = None # evaluate the likelihood on the grid points @@ -207,7 +208,7 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): non_marg_names.remove(name) places = self.sample_points[name] - xp = log_array.__array_namespace__() + xp = aac.get_namespace(log_array) if len(places) > 1: dx = xp.diff(places) @@ -218,7 +219,7 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): # no marginalisation 
required, just remove the singleton dimension z = log_array.shape q = xp.arange(0, len(z)).astype(int) != axis - out = xp.reshape(log_array, tuple((xp.array(list(z)))[q])) + out = xp.reshape(log_array, tuple((xp.asarray(list(z)))[q])) return out @@ -296,7 +297,7 @@ def marginalize_likelihood(self, parameters=None, not_parameters=None): """ ln_like = self.marginalize(self.ln_likelihood, parameters=parameters, not_parameters=not_parameters) - xp = ln_like.__array_namespace__() + xp = aac.get_namespace(ln_like) # NOTE: the output will not be properly normalised return xp.exp(ln_like - xp.max(ln_like)) @@ -321,11 +322,11 @@ def marginalize_posterior(self, parameters=None, not_parameters=None): ln_post = self.marginalize(self.ln_posterior, parameters=parameters, not_parameters=not_parameters) # NOTE: the output will not be properly normalised - xp = ln_post.__array_namespace__() + xp = aac.get_namespace(ln_post) return xp.exp(ln_post - xp.max(ln_post)) def _evaluate(self): - xp = self.mesh_grid[0].__array_namespace__() + xp = aac.get_namespace(self.mesh_grid[0]) if xp.__name__ == "jax.numpy": from jax import vmap self._ln_likelihood = vmap(self.likelihood.log_likelihood)( @@ -339,7 +340,7 @@ def _evaluate(self): def _evaluate_recursion(self, dimension, parameters): if dimension == self.n_dims: - xp = self.mesh_grid[0].__array_namespace__() + xp = aac.get_namespace(self.mesh_grid[0]) current_point = tuple([[xp.where( parameters[name] == self.sample_points[name])[0].item()] for name in self.parameter_names]) diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index 6dc025510..da333f534 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -3,6 +3,7 @@ import os from warnings import warn +import array_api_compat as aac import numpy as np from array_api_compat import is_array_api_obj from scipy.special import gammaln, xlogy @@ -350,7 +351,7 @@ def log_likelihood(self, parameters=None): raise ValueError( "Poisson rate function returns wrong 
value type! " "Is {} when it should be numpy.ndarray".format(type(rate))) - xp = rate.__array_namespace__() + xp = aac.get_namespace(rate) if xp.any(rate < 0.): raise ValueError(("Poisson rate function returns a negative", " value!")) @@ -371,7 +372,7 @@ def y(self): def y(self, y): if not is_array_api_obj(y): y = np.atleast_1d(y) - xp = y.__array_namespace__() + xp = aac.get_namespace(y) # check array is a non-negative integer array if y.dtype.kind not in 'ui' or xp.any(y < 0): raise ValueError("Data must be non-negative integers") @@ -399,7 +400,7 @@ def __init__(self, x, y, func, **kwargs): def log_likelihood(self, parameters=None): mu = self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) - xp = mu.__array_namespace__() + xp = aac.get_namespace(mu) if xp.any(mu < 0.): return -np.inf return -xp.sum(xp.log(mu) + (self.y / mu)) @@ -416,7 +417,7 @@ def y(self): def y(self, y): if not is_array_api_obj(y): y = np.atleast_1d(y) - xp = y.__array_namespace__() + xp = aac.get_namespace(y) if xp.any(y < 0): raise ValueError("Data must be non-negative") self._y = y @@ -584,7 +585,7 @@ def dim(self): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) xp = array_module(self.cov) - x = xp.array([parameters["x{0}".format(i)] for i in range(self.dim)]) + x = xp.asarray([parameters["x{0}".format(i)] for i in range(self.dim)]) return self.logpdf(x) @@ -624,7 +625,7 @@ def dim(self): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) xp = array_module(self.cov) - x = xp.array([parameters["x{0}".format(i)] for i in range(self.dim)]) + x = xp.asarray([parameters["x{0}".format(i)] for i in range(self.dim)]) return -xp.log(2) + xp.logaddexp(self.logpdf_1(x), self.logpdf_2(x)) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 118e36ce8..9dd120964 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ 
-1441,7 +1441,7 @@ def __init__( xp = array_module(values) nvalues = len(values) - values = xp.array(values) + values = xp.asarray(values) if values.shape != (nvalues,): raise ValueError( f"Shape of argument 'values' must be 1d array-like but has shape {values.shape}" @@ -1461,7 +1461,7 @@ def __init__( self.values = self._values_array.tolist() weights = ( - xp.array(weights) / xp.sum(weights) + xp.asarray(weights) / xp.sum(weights) if weights is not None else xp.ones(self.nvalues) / self.nvalues ) @@ -1497,7 +1497,7 @@ def rescale(self, val, *, xp=None): ======= Union[float, array_like]: Rescaled probability """ - index = xp.searchsorted(self._cumulative_weights_array[1:], val) + index = xp.searchsorted(xp.asarray(self._cumulative_weights_array[1:]), val) return xp.asarray(self._values_array)[index] @xp_wrap @@ -1512,7 +1512,7 @@ def cdf(self, val, *, xp=None): ======= float: cumulative prior probability of val """ - index = xp.searchsorted(self._values_array, val, side="right") + index = xp.searchsorted(xp.asarray(self._values_array), val, side="right") return xp.asarray(self._cumulative_weights_array)[index] @xp_wrap @@ -1527,9 +1527,13 @@ def prob(self, val, *, xp=None): ======= float: Prior probability of val """ - index = xp.searchsorted(self._values_array, val) + index = xp.searchsorted(xp.asarray(self._values_array), val) index = xp.clip(index, 0, self.nvalues - 1) - p = xp.where(self._values_array[index] == val, self._weights_array[index], 0) + p = xp.where( + xp.asarray(self._values_array[index])== val, + xp.asarray(self._weights_array[index]), + xp.asarray(0.0), + ) # turn 0d numpy array to scalar return p[()] @@ -1546,10 +1550,12 @@ def ln_prob(self, val, xp=None): float: """ - index = xp.searchsorted(self._values_array, val) + index = xp.searchsorted(xp.asarray(self._values_array), val) index = xp.clip(index, 0, self.nvalues - 1) lnp = xp.where( - self._values_array[index] == val, self._lnweights_array[index], -np.inf + 
xp.asarray(self._values_array[index]) == val, + xp.asarray(self._lnweights_array[index]), + -np.inf, ) # turn 0d array to scalar return lnp[()] diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index 671b46233..137f6df5d 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -155,7 +155,7 @@ def sample(self, size=None, *, xp=np): from ..utils import random self.least_recently_sampled = self.rescale( - xp.array(random.rng.uniform(0, 1, size)) + xp.asarray(random.rng.uniform(0, 1, size)) ) return self.least_recently_sampled diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index c221939b8..f42f83239 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -80,7 +80,7 @@ def sample(self, size=None, *, xp=np, **required_variables): from ..utils import random self.least_recently_sampled = self.rescale( - xp.array(random.rng.uniform(0, 1, size)), + xp.asarray(random.rng.uniform(0, 1, size)), xp=xp, **required_variables, ) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 0eaa92008..a7ca7589f 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -385,7 +385,7 @@ def sample_subset_constrained_as_array(self, keys=iter([]), size=None, *, xp=np) samples_dict = self.sample_subset_constrained(keys=keys, size=size, xp=xp) samples_dict = {key: xp.atleast_1d(val) for key, val in samples_dict.items()} samples_list = [samples_dict[key] for key in keys] - return xp.array(samples_list) + return xp.asarray(samples_list) def sample_subset(self, keys=iter([]), size=None, *, xp=np): """Draw samples from the prior set for parameters which are not a DeltaFunction @@ -850,7 +850,7 @@ def ln_prob(self, sample, axis=None, normalized=True): """ self._prepare_evaluation(*zip(*sample.items())) xp = array_module(sample.values()) - res = xp.array([ + res = xp.asarray([ self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample ]) @@ -910,7 +910,7 
@@ def rescale(self, keys, theta): self[subkey].least_recently_sampled = val result[subkey] = val - return xp.array([result[key] for key in keys]) + return xp.asarray([result[key] for key in keys]) def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 0e8e8abfa..924a24c6e 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -317,7 +317,7 @@ def rescale(self, value, *, xp=None, **kwargs): An vector sample drawn from the multivariate Gaussian distribution. """ - samp = xp.array(value) + samp = xp.asarray(value) if len(samp.shape) == 1: samp = samp.reshape(1, self.num_vars) @@ -627,7 +627,7 @@ def _rescale(self, samp, *, xp=None, **kwargs): samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 # rotate and scale to the multivariate normal shape - samp = xp.array(self.mus[mode]) + self.sigmas[mode] * xp.einsum( + samp = xp.asarray(self.mus[mode]) + self.sigmas[mode] * xp.einsum( "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode] ) return samp @@ -674,7 +674,7 @@ def _sample(self, size, *, xp=np, **kwargs): if not outbound: inbound = True - return xp.array(samps) + return xp.asarray(samps) @xp_wrap def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): @@ -801,7 +801,7 @@ def rescale(self, val, *, xp=None, **kwargs): self.dist.rescale_parameters[self.name] = val if self.dist.filled_rescale(): - values = xp.array(list(self.dist.rescale_parameters.values())).T + values = xp.asarray(list(self.dist.rescale_parameters.values())).T samples = self.dist.rescale(values, **kwargs) self.dist.reset_rescale() return samples @@ -871,14 +871,14 @@ def ln_prob(self, val, *, xp=None): # check for the same number of values for each parameter shapes = set() for v in values: - shapes.add(xp.array(v).shape) + shapes.add(xp.asarray(v).shape) if len(shapes) > 1: raise ValueError( "Each parameter must have the same " "number of requested values." 
) - lnp = self.dist.ln_prob(xp.array(values).T) + lnp = self.dist.ln_prob(xp.asarray(values).T) # reset the requested parameters self.dist.reset_request() diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index 889e086b0..6852299ec 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -194,7 +194,7 @@ def logtrapzexp(lnf, dx, *, xp=np): else: raise TypeError("Step size must be a single value or array-like") - return C + logsumexp(xp.array([logsumexp(lnfdx1), logsumexp(lnfdx2)])) + return C + logsumexp(xp.asarray([logsumexp(lnfdx1), logsumexp(lnfdx2)])) class interp1d(_interp1d): diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index a5502a1a6..d24b16fe2 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -8,6 +8,7 @@ from pathlib import Path from datetime import timedelta +import array_api_compat as aac import numpy as np import pandas as pd @@ -62,7 +63,7 @@ def default(self, obj): if hasattr(obj, "__array_namespace__"): return { "__array__": True, - "__array_namespace__": obj.__array_namespace__().__name__, + "__array_namespace__": aac.get_namespace(obj).__name__, "content": obj.tolist(), } if isinstance(obj, complex): diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index f3f1e5118..ad751c4fa 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -632,8 +632,8 @@ def spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3) ''' xp = array_module(gamma_pca_0) - sampled_pca_gammas = xp.array([gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3]) - transformation_matrix = xp.array( + sampled_pca_gammas = xp.asarray([gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3]) + transformation_matrix = xp.asarray( [ [0.43801, -0.76705, 0.45143, 0.12646], [-0.53573, 0.17169, 0.67968, 0.47070], @@ -642,8 +642,8 @@ def spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3) ] ) - model_space_mean = xp.array([0.89421, 0.33878, -0.07894, 
0.00393]) - model_space_standard_deviation = xp.array([0.35700, 0.25769, 0.05452, 0.00312]) + model_space_mean = xp.asarray([0.89421, 0.33878, -0.07894, 0.00393]) + model_space_standard_deviation = xp.asarray([0.35700, 0.25769, 0.05452, 0.00312]) converted_gamma_parameters = \ model_space_mean + model_space_standard_deviation * xp.dot(transformation_matrix, sampled_pca_gammas) diff --git a/bilby/gw/detector/calibration.py b/bilby/gw/detector/calibration.py index 6f7390bde..883275016 100644 --- a/bilby/gw/detector/calibration.py +++ b/bilby/gw/detector/calibration.py @@ -42,6 +42,7 @@ import copy import os +import array_api_compat as aac import numpy as np import pandas as pd from array_api_compat import is_jax_namespace @@ -333,9 +334,9 @@ def __repr__(self): def _evaluate_spline(self, kind, a, b, c, d, previous_nodes): """Evaluate Eq. (1) in https://dcc.ligo.org/LIGO-T2300140""" xp = array_module(self.params[f"{kind}_0"]) - parameters = xp.array([self.params[f"{kind}_{ii}"] for ii in range(self.n_points)]) + parameters = xp.asarray([self.params[f"{kind}_{ii}"] for ii in range(self.n_points)]) next_nodes = previous_nodes + 1 - nodes = xp.array(self.nodes_to_spline_coefficients) + nodes = xp.asarray(self.nodes_to_spline_coefficients) spline_coefficients = nodes.dot(parameters) return ( a * parameters[previous_nodes] @@ -377,7 +378,7 @@ def get_calibration_factor(self, frequency_array, **params): delta_amplitude = self._evaluate_spline("amplitude", a, b, c, d, previous_nodes) delta_phase = self._evaluate_spline("phase", a, b, c, d, previous_nodes) calibration_factor = (1 + delta_amplitude) * (2 + 1j * delta_phase) / (2 - 1j * delta_phase) - xp = calibration_factor.__array_namespace__() + xp = aac.get_namespace(calibration_factor) return xp.nan_to_num(calibration_factor) @@ -412,7 +413,7 @@ def get_calibration_factor(self, frequency_array, **params): if idx is None: raise KeyError(f"Calibration index for {self.label} not found.") - xp = 
frequency_array.__array_namespace__() + xp = aac.get_namespace(frequency_array) if not xp.array_equal(frequency_array, self.frequency_array): intersection, mask, _ = xp.intersect1d( frequency_array, self.frequency_array, return_indices=True diff --git a/bilby/gw/detector/geometry.py b/bilby/gw/detector/geometry.py index 5d0de9b9f..a6c2df168 100644 --- a/bilby/gw/detector/geometry.py +++ b/bilby/gw/detector/geometry.py @@ -306,11 +306,11 @@ def unit_vector_along_arm(self, arm): raise ValueError("Arm must either be 'x' or 'y'.") def set_array_backend(self, xp): - self.length = xp.array(self.length) - self.latitude = xp.array(self.latitude) - self.longitude = xp.array(self.longitude) - self.elevation = xp.array(self.elevation) - self.xarm_azimuth = xp.array(self.xarm_azimuth) - self.yarm_azimuth = xp.array(self.yarm_azimuth) - self.xarm_tilt = xp.array(self.xarm_tilt) - self.yarm_tilt = xp.array(self.yarm_tilt) + self.length = xp.asarray(self.length) + self.latitude = xp.asarray(self.latitude) + self.longitude = xp.asarray(self.longitude) + self.elevation = xp.asarray(self.elevation) + self.xarm_azimuth = xp.asarray(self.xarm_azimuth) + self.yarm_azimuth = xp.asarray(self.yarm_azimuth) + self.xarm_tilt = xp.asarray(self.xarm_tilt) + self.yarm_tilt = xp.asarray(self.yarm_tilt) diff --git a/bilby/gw/geometry.py b/bilby/gw/geometry.py index c07ec5c0c..54d2f3a1d 100644 --- a/bilby/gw/geometry.py +++ b/bilby/gw/geometry.py @@ -30,15 +30,15 @@ def antenna_response(detector_tensor, ra, dec, time, psi, mode): def calculate_arm(arm_tilt, arm_azimuth, longitude, latitude): """""" xp = array_module(arm_tilt) - e_long = xp.array([-xp.sin(longitude), xp.cos(longitude), longitude * 0]) - e_lat = xp.array( + e_long = xp.asarray([-xp.sin(longitude), xp.cos(longitude), longitude * 0]) + e_lat = xp.asarray( [ -xp.sin(latitude) * xp.cos(longitude), -xp.sin(latitude) * xp.sin(longitude), xp.cos(latitude), ] ) - e_h = xp.array( + e_h = xp.asarray( [ xp.cos(latitude) * xp.cos(longitude), 
xp.cos(latitude) * xp.sin(longitude), @@ -70,17 +70,17 @@ def get_polarization_tensor(ra, dec, time, psi, mode): gmst = greenwich_mean_sidereal_time(time) % (2 * xp.pi) phi = ra - gmst theta = xp.atleast_1d(xp.pi / 2 - dec).squeeze() - u = xp.array( + u = xp.asarray( [ xp.cos(phi) * xp.cos(theta), xp.cos(theta) * xp.sin(phi), -xp.sin(theta) * xp.ones_like(phi), ] ) - v = xp.array([ + v = xp.asarray([ -xp.sin(phi), xp.cos(phi), xp.zeros_like(phi) ]) * xp.ones_like(theta) - omega = xp.array([ + omega = xp.asarray([ xp.sin(xp.pi - theta) * xp.cos(xp.pi + phi), xp.sin(xp.pi - theta) * xp.sin(xp.pi + phi), xp.cos(xp.pi - theta) * xp.ones_like(phi), @@ -124,21 +124,21 @@ def rotation_matrix_from_delta(delta_x): alpha = xp.arctan2(-delta_x[1] * delta_x[2], delta_x[0]) beta = xp.arccos(delta_x[2]) gamma = xp.arctan2(delta_x[1], delta_x[0]) - rotation_1 = xp.array( + rotation_1 = xp.asarray( [ [xp.cos(alpha), -xp.sin(alpha), xp.zeros(alpha.shape)], [xp.sin(alpha), xp.cos(alpha), xp.zeros(alpha.shape)], [xp.zeros(alpha.shape), xp.zeros(alpha.shape), xp.ones(alpha.shape)], ] ) - rotation_2 = xp.array( + rotation_2 = xp.asarray( [ [xp.cos(beta), xp.zeros(beta.shape), xp.sin(beta)], [xp.zeros(beta.shape), xp.ones(beta.shape), xp.zeros(beta.shape)], [-xp.sin(beta), xp.zeros(beta.shape), xp.cos(beta)], ] ) - rotation_3 = xp.array( + rotation_3 = xp.asarray( [ [xp.cos(gamma), -xp.sin(gamma), xp.zeros(gamma.shape)], [xp.sin(gamma), xp.cos(gamma), xp.zeros(gamma.shape)], @@ -163,7 +163,7 @@ def time_delay_geocentric(detector1, detector2, ra, dec, time): speed_of_light = 299792458.0 phi = ra - gmst theta = xp.pi / 2 - dec - omega = xp.array( + omega = xp.asarray( [xp.sin(theta) * xp.cos(phi), xp.sin(theta) * xp.sin(phi), xp.cos(theta)] ) delta_d = detector2 - detector1 @@ -181,7 +181,7 @@ def time_delay_from_geocenter(detector1, ra, dec, time): def zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x): """""" xp = array_module(delta_x) - omega_prime = xp.array( + omega_prime = 
xp.asarray( [ xp.sin(zenith) * xp.cos(azimuth), xp.sin(zenith) * xp.sin(azimuth), diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 8f81f950b..9e9bdb28e 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -1,4 +1,5 @@ +import array_api_compat as aac import numpy as np from .base import GravitationalWaveTransient @@ -567,7 +568,7 @@ def _interp_five_samples(time_samples, values, time): value: float The value of the function at the input time """ - xp = time_samples.__array_namespace__() + xp = aac.get_namespace(time_samples) r1 = (-values[0] + 8. * values[1] - 14. * values[2] + 8. * values[3] - values[4]) / 4. r2 = values[2] - 2. * values[3] + values[4] a = (time_samples[3] - time) / xp.maximum(time_samples[1] - time_samples[0], 1e-12) diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 3fe9ed242..cd0cf182d 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -1554,11 +1554,11 @@ def _rescale(self, samp, *, xp=np, **kwargs): for i, val in enumerate(pix_rescale): theta, ra = self.hp.pix2ang(self.nside, int(round(val))) dec = 0.5 * np.pi - theta - sample = xpx.at(sample, i).set(xp.array(self.draw_from_pixel(ra, dec, int(round(val))))) + sample = xpx.at(sample, i).set(xp.asarray(self.draw_from_pixel(ra, dec, int(round(val))))) if self.distance: self.update_distance(int(round(val))) dist_samples = xpx.at(dist_samples, i).set( - xp.array(self.distance_icdf(dist_samp[i])) + xp.asarray(self.distance_icdf(dist_samp[i])) ) if self.distance: sample = xp.vstack([sample[:, 0], sample[:, 1], dist_samples]) @@ -1638,7 +1638,7 @@ def _sample(self, size, *, xp=np, **kwargs): sample[samp, :] = [ra_dec[0], ra_dec[1], dist] else: sample[samp, :] = self.draw_from_pixel(ra, dec, sample_pix[samp]) - return xp.array(sample.reshape((-1, self.num_vars))) + return xp.asarray(sample.reshape((-1, self.num_vars))) def draw_distance(self, pix): """ diff --git a/bilby/gw/source.py b/bilby/gw/source.py index 96973efd9..cc08d2d65 100644 --- 
a/bilby/gw/source.py +++ b/bilby/gw/source.py @@ -1292,7 +1292,7 @@ def supernova_pca_model( coefficients = [pc_coeff1, pc_coeff2, pc_coeff3, pc_coeff4, pc_coeff5] strain = xp.sum( - xp.array([coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)]), + xp.asarray([coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)]), axis=0 ) diff --git a/bilby/gw/time.py b/bilby/gw/time.py index 996bad070..3c115646b 100644 --- a/bilby/gw/time.py +++ b/bilby/gw/time.py @@ -189,7 +189,7 @@ def n_leap_seconds(gps_time, leap_seconds): @dispatch def n_leap_seconds(gps_time: np.ndarray | float | int): # noqa F811 xp = array_module(gps_time) - return n_leap_seconds(gps_time, xp.array(LEAP_SECONDS)) + return n_leap_seconds(gps_time, xp.asarray(LEAP_SECONDS)) @dispatch diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index 013ecc581..13879b37d 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -85,7 +85,7 @@ def get_vertex_position_geocentric(latitude, longitude, elevation): x_comp = (radius + elevation) * xp.cos(latitude) * xp.cos(longitude) y_comp = (radius + elevation) * xp.cos(latitude) * xp.sin(longitude) z_comp = ((semi_minor_axis / semi_major_axis)**2 * radius + elevation) * xp.sin(latitude) - return xp.array([x_comp, y_comp, z_comp]) + return xp.asarray([x_comp, y_comp, z_comp]) def inner_product(aa, bb, frequency, PSD): diff --git a/bilby/hyper/likelihood.py b/bilby/hyper/likelihood.py index 2c4b65abd..629324a09 100644 --- a/bilby/hyper/likelihood.py +++ b/bilby/hyper/likelihood.py @@ -111,5 +111,5 @@ def resample_posteriors(self, max_samples=None, xp=np): for key in data: data[key].append(temp[key]) for key in data: - data[key] = xp.array(data[key]) + data[key] = xp.asarray(data[key]) return data diff --git a/docs/array_api.rst b/docs/array_api.rst index f3d293b02..6ce77c28a 100644 --- a/docs/array_api.rst +++ b/docs/array_api.rst @@ -351,7 +351,7 @@ All array-processing methods in prior classes follow this pattern: def 
sample(self, size=None, *, xp=np): """Method that uses xp but isn't wrapped.""" - return xp.array(random.rng.uniform(0, 1, size)) + return xp.asarray(random.rng.uniform(0, 1, size)) Key rules: @@ -399,17 +399,17 @@ backends using the ``array_backend`` marker: class TestMyPrior: def test_prob(self): prior = MyPrior() - val = self.xp.array([0.5, 1.5, 2.5]) + val = self.xp.asarray([0.5, 1.5, 2.5]) # No need to pass xp - automatically detected prob = prior.prob(val) assert self.xp.all(prob >= 0) - assert prob.__array_namespace__() == self.xp + assert aac.get_namespace(prob) == self.xp def test_sample(self): prior = MyPrior() # Sampling requires explicit xp samples = prior.sample(size=100, xp=self.xp) - assert samples.__array_namespace__() == self.xp + assert aac.get_namespace(samples) == self.xp The array_backend Marker ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/conftest.py b/test/conftest.py index f7f5a17a3..2bcd416cd 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,4 +1,6 @@ import importlib + +import array_api_compat as aac import pytest @@ -38,20 +40,20 @@ def _xp(request): backend = request.config.getoption("--array-backend") match backend: case None | "numpy": - import numpy - return numpy + import numpy as xp case "jax" | "jax.numpy": import os import jax os.environ["SCIPY_ARRAY_API"] = "1" jax.config.update("jax_enable_x64", True) - return jax.numpy + xp = jax.numpy case _: try: - importlib.import_module(backend) + xp = importlib.import_module(backend) except ImportError: raise ValueError(f"Unknown backend for testing: {backend}") + return aac.get_namespace(xp.ones(1)) @pytest.fixture diff --git a/test/core/grid_test.py b/test/core/grid_test.py index d99af8327..009ab2c15 100644 --- a/test/core/grid_test.py +++ b/test/core/grid_test.py @@ -14,8 +14,8 @@ class MultiGaussian(bilby.Likelihood): def __init__(self, mean, cov, *, xp=np): super(MultiGaussian, self).__init__() self.xp = xp - self.cov = xp.array(cov) - self.mean = xp.array(mean) + self.cov = 
xp.asarray(cov) + self.mean = xp.asarray(mean) self.sigma = xp.sqrt(xp.diag(self.cov)) self.logpdf = multivariate_logpdf(xp=xp, mean=self.mean, cov=self.cov) @@ -24,7 +24,7 @@ def dim(self): return len(self.cov[0]) def log_likelihood(self, parameters): - x = self.xp.array([parameters["x{0}".format(i)] for i in range(self.dim)]) + x = self.xp.asarray([parameters["x{0}".format(i)] for i in range(self.dim)]) return self.logpdf(x) diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index 80d1324c8..ab12208c8 100644 --- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -1,6 +1,7 @@ import unittest from unittest import mock +import array_api_compat as aac import numpy as np import pytest import array_api_extra as xpx @@ -172,13 +173,13 @@ def setUp(self): self.N = 100 self.sigma = 0.1 self.x = self.xp.linspace(0, 1, self.N) - self.y = 2 * self.x + 1 + self.xp.array(np.random.normal(0, self.sigma, self.N)) + self.y = 2 * self.x + 1 + self.xp.asarray(np.random.normal(0, self.sigma, self.N)) def test_function(x, m, c): return m * x + c self.function = test_function - self.parameters = dict(m=self.xp.array(2.0), c=self.xp.array(0.0)) + self.parameters = dict(m=self.xp.asarray(2.0), c=self.xp.asarray(0.0)) def tearDown(self): del self.N @@ -225,7 +226,7 @@ def test_repr(self): def test_return_class(self): likelihood = GaussianLikelihood(self.x, self.y, self.function, self.sigma) logl = likelihood.log_likelihood(self.parameters) - self.assertEqual(logl.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(logl), self.xp) @pytest.mark.array_backend @@ -236,13 +237,13 @@ def setUp(self): self.nu = self.N - 2 self.sigma = 1 self.x = self.xp.linspace(0, 1, self.N) - self.y = 2 * self.x + 1 + self.xp.array(np.random.normal(0, self.sigma, self.N)) + self.y = 2 * self.x + 1 + self.xp.asarray(np.random.normal(0, self.sigma, self.N)) def test_function(x, m, c): return m * x + c self.function = test_function - self.parameters = 
dict(m=self.xp.array(2.0), c=self.xp.array(0.0)) + self.parameters = dict(m=self.xp.asarray(2.0), c=self.xp.asarray(0.0)) def tearDown(self): del self.N @@ -306,7 +307,7 @@ def setUp(self): self.N = 100 self.mu = 5 self.x = self.xp.linspace(0, 1, self.N) - self.y = self.xp.array(np.random.poisson(self.mu, self.N)) + self.y = self.xp.asarray(np.random.poisson(self.mu, self.N)) self.yfloat = self.y.copy() * 1.0 self.yneg = self.y.copy() self.yneg = xpx.at(self.yneg, 0).set(-1) @@ -320,7 +321,7 @@ def test_function_array(x, c): self.function = test_function self.function_array = test_function_array self.poisson_likelihood = PoissonLikelihood(self.x, self.y, self.function) - self.bad_parameters = dict(c=self.xp.array(-2.0)) + self.bad_parameters = dict(c=self.xp.asarray(-2.0)) def tearDown(self): del self.N @@ -381,14 +382,14 @@ def test_log_likelihood_wrong_func_return_type(self): def test_log_likelihood_negative_func_return_element(self): poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: self.xp.array([3, 6, -2]) + x=self.x, y=self.y, func=lambda x: self.xp.asarray([3, 6, -2]) ) with self.assertRaises(ValueError): poisson_likelihood.log_likelihood() def test_log_likelihood_zero_func_return_element(self): poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: self.xp.array([3, 6, 0]) + x=self.x, y=self.y, func=lambda x: self.xp.asarray([3, 6, 0]) ) self.assertEqual(-np.inf, poisson_likelihood.log_likelihood()) @@ -416,7 +417,7 @@ def setUp(self): self.N = 100 self.mu = 5 self.x = self.xp.linspace(0, 1, self.N) - self.y = self.xp.array(np.random.exponential(self.mu, self.N)) + self.y = self.xp.asarray(np.random.exponential(self.mu, self.N)) self.yneg = self.y.copy() self.yneg = xpx.at(self.yneg, 0).set(-1.0) @@ -431,7 +432,7 @@ def test_function_array(x, c): self.exponential_likelihood = ExponentialLikelihood( x=self.x, y=self.y, func=self.function ) - self.bad_parameters = dict(c=self.xp.array(-1.0)) + 
self.bad_parameters = dict(c=self.xp.asarray(-1.0)) def tearDown(self): del self.N @@ -483,12 +484,12 @@ def test_set_y_to_negative_float(self): def test_set_y_to_nd_array_with_negative_element(self): with self.assertRaises(ValueError): - self.exponential_likelihood.y = self.xp.array([4.3, -1.2, 4]) + self.exponential_likelihood.y = self.xp.asarray([4.3, -1.2, 4]) def test_log_likelihood_default(self): """ Merely tests that it ends up at the right place in the code """ exponential_likelihood = ExponentialLikelihood( - x=self.x, y=self.y, func=lambda x: self.xp.array([4.2]) + x=self.x, y=self.y, func=lambda x: self.xp.asarray([4.2]) ) with mock.patch(f"{self.xp.__name__}.sum") as m: m.return_value = 3 @@ -509,9 +510,9 @@ def setUp(self): self.sigma = [1, 2, 3] self.mean = [10, 11, 12] if self.xp != np: - self.cov = self.xp.array(self.cov) - self.sigma = self.xp.array(self.sigma) - self.mean = self.xp.array(self.mean) + self.cov = self.xp.asarray(self.cov) + self.sigma = self.xp.asarray(self.sigma) + self.mean = self.xp.asarray(self.mean) self.likelihood = AnalyticalMultidimensionalCovariantGaussian( mean=self.mean, cov=self.cov ) @@ -537,14 +538,14 @@ def test_dim(self): def test_log_likelihood(self): likelihood = AnalyticalMultidimensionalCovariantGaussian( - mean=self.xp.array([0]), cov=self.xp.array([1]) + mean=self.xp.asarray([0]), cov=self.xp.asarray([1]) ) - logl = likelihood.log_likelihood(dict(x0=self.xp.array(0.0))) + logl = likelihood.log_likelihood(dict(x0=self.xp.asarray(0.0))) self.assertEqual( -np.log(2 * np.pi) / 2, - likelihood.log_likelihood(dict(x0=self.xp.array(0.0))), + likelihood.log_likelihood(dict(x0=self.xp.asarray(0.0))), ) - self.assertEqual(logl.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(logl), self.xp) @pytest.mark.array_backend @@ -556,10 +557,10 @@ def setUp(self): self.mean_1 = [10, 11, 12] self.mean_2 = [20, 21, 22] if self.xp != np: - self.cov = self.xp.array(self.cov) - self.sigma = 
self.xp.array(self.sigma) - self.mean_1 = self.xp.array(self.mean_1) - self.mean_2 = self.xp.array(self.mean_2) + self.cov = self.xp.asarray(self.cov) + self.sigma = self.xp.asarray(self.sigma) + self.mean_1 = self.xp.asarray(self.mean_1) + self.mean_2 = self.xp.asarray(self.mean_2) self.likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian( mean_1=self.mean_1, mean_2=self.mean_2, cov=self.cov ) @@ -593,7 +594,7 @@ def test_log_likelihood(self): ) self.assertEqual( -np.log(2 * np.pi) / 2, - likelihood.log_likelihood(dict(x0=self.xp.array(0.0))), + likelihood.log_likelihood(dict(x0=self.xp.asarray(0.0))), ) diff --git a/test/core/prior/analytical_test.py b/test/core/prior/analytical_test.py index 8325be232..dddc0dfba 100644 --- a/test/core/prior/analytical_test.py +++ b/test/core/prior/analytical_test.py @@ -1,5 +1,6 @@ import unittest +import array_api_compat as aac import bilby import numpy as np import pytest @@ -38,23 +39,23 @@ def test_single_probability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.prob(self.xp.array(1.1)), 1 / N) - self.assertEqual(discrete_value_prior.prob(self.xp.array(2.2)), 1 / N) - self.assertEqual(discrete_value_prior.prob(self.xp.array(300.0)), 1 / N) - self.assertEqual(discrete_value_prior.prob(self.xp.array(0.5)), 0) - self.assertEqual(discrete_value_prior.prob(self.xp.array(200)), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(1.1)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(2.2)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(300.0)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(0.5)), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(200)), 0) def test_single_probability_unsorted(self): N = 3 values = [1.1, 300, 2.2] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - 
self.assertEqual(discrete_value_prior.prob(self.xp.array(1.1)), 1 / N) - self.assertEqual(discrete_value_prior.prob(self.xp.array(2.2)), 1 / N) - self.assertEqual(discrete_value_prior.prob(self.xp.array(300.0)), 1 / N) - self.assertEqual(discrete_value_prior.prob(self.xp.array(0.5)), 0) - self.assertEqual(discrete_value_prior.prob(self.xp.array(200)), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(1.1)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(2.2)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(300.0)), 1 / N) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(0.5)), 0) + self.assertEqual(discrete_value_prior.prob(self.xp.asarray(200)), 0) self.assertEqual( - discrete_value_prior.prob(self.xp.array(0.5)).__array_namespace__(), + aac.get_namespace(discrete_value_prior.prob(self.xp.asarray(0.5))), self.xp, ) @@ -64,7 +65,7 @@ def test_array_probability(self): discrete_value_prior = bilby.core.prior.DiscreteValues(values) self.assertTrue( np.all( - discrete_value_prior.prob(self.xp.array([1.1, 2.2, 2.2, 300.0, 200.0])) + discrete_value_prior.prob(self.xp.asarray([1.1, 2.2, 2.2, 300.0, 200.0])) == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) ) ) @@ -73,12 +74,12 @@ def test_single_lnprobability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertEqual(discrete_value_prior.ln_prob(self.xp.array(1.1)), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(self.xp.array(2.2)), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(self.xp.array(300)), -np.log(N)) - self.assertEqual(discrete_value_prior.ln_prob(self.xp.array(150)), -np.inf) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(1.1)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(2.2)), -np.log(N)) + self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(300)), -np.log(N)) + 
self.assertEqual(discrete_value_prior.ln_prob(self.xp.asarray(150)), -np.inf) self.assertEqual( - discrete_value_prior.ln_prob(self.xp.array(0.5)).__array_namespace__(), + aac.get_namespace(discrete_value_prior.ln_prob(self.xp.asarray(0.5))), self.xp, ) @@ -88,7 +89,7 @@ def test_array_lnprobability(self): discrete_value_prior = bilby.core.prior.DiscreteValues(values) self.assertTrue( np.all( - discrete_value_prior.ln_prob(self.xp.array([1.1, 2.2, 2.2, 300, 150])) + discrete_value_prior.ln_prob(self.xp.asarray([1.1, 2.2, 2.2, 300, 150])) == np.array([-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]) ) ) @@ -111,6 +112,8 @@ def test_array_sample(self): categorical_prior = bilby.core.prior.Categorical(ncat) N = 100000 s = categorical_prior.sample(N, xp=self.xp) + self.assertEqual(aac.get_namespace(s), self.xp) + s = np.asarray(s) zeros = np.sum(s == 0) ones = np.sum(s == 1) twos = np.sum(s == 2) @@ -120,46 +123,48 @@ def test_array_sample(self): self.assertAlmostEqual(ones / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) self.assertAlmostEqual(twos / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) self.assertAlmostEqual(threes / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) - self.assertEqual(s.__array_namespace__(), self.xp) def test_single_probability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertEqual(categorical_prior.prob(self.xp.array(0)), 1 / N) - self.assertEqual(categorical_prior.prob(self.xp.array(1)), 1 / N) - self.assertEqual(categorical_prior.prob(self.xp.array(2)), 1 / N) - self.assertEqual(categorical_prior.prob(self.xp.array(0.5)), 0) + self.assertEqual(categorical_prior.prob(self.xp.asarray(0)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.asarray(1)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.asarray(2)), 1 / N) + self.assertEqual(categorical_prior.prob(self.xp.asarray(0.5)), 0) self.assertEqual( - categorical_prior.prob(self.xp.array(0.5)).__array_namespace__(), + 
aac.get_namespace(categorical_prior.prob(self.xp.asarray(0.5))), self.xp, ) def test_array_probability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) + probs = categorical_prior.prob(self.xp.asarray([0, 1, 1, 2, 3])) + self.assertEqual(aac.get_namespace(probs), self.xp) + self.assertTrue(np.all( - categorical_prior.prob(self.xp.array([0, 1, 1, 2, 3])) - == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) + np.asarray(probs) == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) )) def test_single_lnprobability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertEqual(categorical_prior.ln_prob(self.xp.array(0)), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(self.xp.array(1)), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(self.xp.array(2)), -np.log(N)) - self.assertEqual(categorical_prior.ln_prob(self.xp.array(0.5)), -np.inf) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(0)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(1)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(2)), -np.log(N)) + self.assertEqual(categorical_prior.ln_prob(self.xp.asarray(0.5)), -np.inf) self.assertEqual( - categorical_prior.ln_prob(self.xp.array(0.5)).__array_namespace__(), + aac.get_namespace(categorical_prior.ln_prob(self.xp.asarray(0.5))), self.xp, ) def test_array_lnprobability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) + ln_prob = categorical_prior.ln_prob(self.xp.asarray([0, 1, 1, 2, 3])) + self.assertEqual(aac.get_namespace(ln_prob), self.xp) self.assertTrue(np.all( - categorical_prior.ln_prob(self.xp.array([0, 1, 1, 2, 3])) - == np.array([-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]) + np.asarray(ln_prob) == np.array([-np.log(N)] * 4 + [-np.inf]) )) @@ -187,48 +192,50 @@ def test_array_sample(self): categorical_prior = bilby.core.prior.WeightedCategorical(ncat, weights=weights) N = 100000 s = categorical_prior.sample(N, xp=self.xp) + 
self.assertEqual(aac.get_namespace(s), self.xp) + s = np.asarray(s) cases = 0 - for i in self.xp.array(categorical_prior.values): + for i in categorical_prior.values: case = np.sum(s == i) cases += case self.assertAlmostEqual(case / N, categorical_prior.prob(i), places=int(np.log10(np.sqrt(N)))) self.assertAlmostEqual(case / N, weights[i] / np.sum(weights), places=int(np.log10(np.sqrt(N)))) self.assertEqual(cases, N) - self.assertEqual(s.__array_namespace__(), self.xp) + def test_single_probability(self): N = 3 weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) - for i in self.xp.array(categorical_prior.values): + for i in self.xp.asarray(categorical_prior.values): self.assertEqual(categorical_prior.prob(i), weights[i] / np.sum(weights)) - prob = categorical_prior.prob(self.xp.array(0.5)) + prob = categorical_prior.prob(self.xp.asarray(0.5)) self.assertEqual(prob, 0) - self.assertEqual(prob.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(prob), self.xp) def test_array_probability(self): N = 3 - test_cases = self.xp.array([0, 1, 1, 2, 3]) + test_cases = self.xp.asarray([0, 1, 1, 2, 3]) weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) probs = np.arange(1, N + 2) / np.sum(weights) probs[-1] = 0 new = categorical_prior.prob(test_cases) - self.assertTrue(np.all(new == probs[test_cases])) - self.assertEqual(new.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(new), self.xp) + self.assertTrue(np.all(np.asarray(new) == probs[test_cases])) def test_single_lnprobability(self): N = 3 weights = np.arange(1, N + 1) categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) - for i in self.xp.array(categorical_prior.values): + for i in self.xp.asarray(categorical_prior.values): self.assertEqual( - categorical_prior.ln_prob(self.xp.array(i)), + categorical_prior.ln_prob(self.xp.asarray(i)), np.log(weights[i] / 
np.sum(weights)), ) - prob = categorical_prior.prob(self.xp.array(0.5)) + prob = categorical_prior.prob(self.xp.asarray(0.5)) self.assertEqual(prob, 0) - self.assertEqual(prob.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(prob), self.xp) def test_array_lnprobability(self): N = 3 @@ -239,9 +246,9 @@ def test_array_lnprobability(self): ln_probs = np.log(np.arange(1, N + 2) / np.sum(weights)) ln_probs[-1] = -np.inf - new = categorical_prior.ln_prob(self.xp.array(test_cases)) - self.assertTrue(np.all(new == ln_probs[test_cases])) - self.assertEqual(new.__array_namespace__(), self.xp) + new = categorical_prior.ln_prob(self.xp.asarray(test_cases)) + self.assertEqual(aac.get_namespace(new), self.xp) + self.assertTrue(np.all(np.asarray(new) == ln_probs[test_cases])) def test_cdf(self): """ @@ -255,7 +262,7 @@ def test_cdf(self): categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) sample = categorical_prior.sample(size=10) original = self.xp.asarray(sample) - new = self.xp.array(categorical_prior.rescale( + new = self.xp.asarray(categorical_prior.rescale( categorical_prior.cdf(sample) )) np.testing.assert_array_equal(original, new) diff --git a/test/core/prior/base_test.py b/test/core/prior/base_test.py index 21e857fa2..84999d42c 100644 --- a/test/core/prior/base_test.py +++ b/test/core/prior/base_test.py @@ -1,6 +1,7 @@ import unittest from unittest.mock import Mock +import array_api_compat as aac import numpy as np import pytest @@ -188,7 +189,7 @@ def rescale(self, val): prior = CustomPriorWithoutXp(name="custom_prior") import jax.numpy as jnp rescaled = prior.rescale(jnp.array([0.1, 0.2, 3])) - self.assertEqual(rescaled.__array_namespace__(), jnp) + self.assertEqual(aac.get_namespace(rescaled), jnp) if __name__ == "__main__": diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index cafa0f73e..5850f5ecf 100644 --- a/test/core/prior/conditional_test.py +++ 
b/test/core/prior/conditional_test.py @@ -3,6 +3,7 @@ import unittest from unittest import mock +import array_api_compat as aac import numpy as np import pandas as pd import pickle @@ -212,10 +213,10 @@ def condition_func_3(reference_parameters, var_1, var_2): bilby.core.prior.ConditionalPriorDict() ) self.test_sample = dict( - var_0=self.xp.array(0.7), - var_1=self.xp.array(0.6), - var_2=self.xp.array(0.5), - var_3=self.xp.array(0.4), + var_0=self.xp.asarray(0.7), + var_1=self.xp.asarray(0.6), + var_2=self.xp.asarray(0.5), + var_3=self.xp.asarray(0.4), ) self.test_value = 1 / np.prod([self.test_sample[f"var_{ii}"] for ii in range(3)]) for key, value in dict( @@ -270,7 +271,7 @@ def test_conditional_keys_setting_items(self): def test_prob(self): prob = self.conditional_priors.prob(sample=self.test_sample) self.assertEqual(self.test_value, prob) - self.assertEqual(prob.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(prob), self.xp) def test_prob_illegal_conditions(self): del self.conditional_priors["var_0"] @@ -366,7 +367,7 @@ def test_rescale_with_joint_prior(self): res = priordict.rescale(keys=keys, theta=ref_variables) self.assertEqual(np.shape(res), (6,)) - self.assertEqual(res.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(res), self.xp) # check conditional values are still as expected expected = [self.test_sample["var_0"]] @@ -470,7 +471,7 @@ def tearDown(self): def test_samples_correct_type(self): samples = self.priors.sample(10, xp=self.xp) - self.assertEqual(samples["dirichlet_1"].__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(samples["dirichlet_1"]), self.xp) def test_samples_sum_to_less_than_one(self): """ diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index 96479326d..3101d181a 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -2,6 +2,7 @@ import unittest from unittest.mock import Mock, patch +import array_api_compat as aac import 
numpy as np import pytest @@ -277,7 +278,7 @@ def test_sample_subset_correct_size(self): self.assertEqual(len(self.prior_set_from_dict), len(samples)) for key in samples: self.assertEqual(size, len(samples[key])) - self.assertEqual(samples[key].__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(samples[key]), self.xp) def test_sample_subset_correct_size_when_non_priors_in_dict(self): self.prior_set_from_dict["asdf"] = "not_a_prior" @@ -287,16 +288,16 @@ def test_sample_subset_correct_size_when_non_priors_in_dict(self): ) self.assertEqual(len(self.prior_set_from_dict) - 1, len(samples)) for key in samples: - self.assertIsNotNone(samples[key].__array_namespace__(), self.xp) + self.assertIsNotNone(aac.get_namespace(samples[key]), self.xp) def test_sample_subset_with_actual_subset(self): size = 3 samples = self.prior_set_from_dict.sample_subset( keys=["length"], size=size, xp=self.xp ) - expected = dict(length=self.xp.array([42.0, 42.0, 42.0])) + expected = dict(length=self.xp.asarray([42.0, 42.0, 42.0])) self.assertTrue(np.array_equal(expected["length"], samples["length"])) - self.assertEqual(samples["length"].__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(samples["length"]), self.xp) def test_sample_subset_constrained_as_array(self): size = 3 @@ -304,7 +305,7 @@ def test_sample_subset_constrained_as_array(self): out = self.prior_set_from_dict.sample_subset_constrained_as_array( keys, size, xp=self.xp ) - self.assertEqual(out.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(out), self.xp) self.assertTrue(out.shape == (len(keys), size)) def test_sample_subset_constrained(self): @@ -356,8 +357,8 @@ def test_sample(self): self.assertEqual(set(samples1.keys()), set(samples2.keys())) for key in samples1: self.assertTrue(np.array_equal(samples1[key], samples2[key])) - self.assertEqual(samples1[key].__array_namespace__(), self.xp) - self.assertEqual(samples2[key].__array_namespace__(), self.xp) + 
self.assertEqual(aac.get_namespace(samples1[key]), self.xp) + self.assertEqual(aac.get_namespace(samples2[key]), self.xp) def test_prob(self): samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"], xp=self.xp) @@ -365,7 +366,7 @@ def test_prob(self): samples["speed"] ) self.assertEqual(expected, self.prior_set_from_dict.prob(samples)) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(expected), self.xp) def test_ln_prob(self): samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"], xp=self.xp) @@ -373,7 +374,7 @@ def test_ln_prob(self): samples["mass"] ) + self.second_prior.ln_prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(expected), self.xp) def test_rescale(self): theta = [0.5, 0.5, 0.5] @@ -398,13 +399,13 @@ def test_cdf(self): Note that the format of inputs/outputs is different between the two methods. 
""" sample = self.prior_set_from_dict.sample(xp=self.xp) - original = self.xp.array(list(sample.values())) - new = self.xp.array(self.prior_set_from_dict.rescale( + original = self.xp.asarray(list(sample.values())) + new = self.xp.asarray(self.prior_set_from_dict.rescale( sample.keys(), self.prior_set_from_dict.cdf(sample=sample).values() )) self.assertLess(max(abs(original - new)), 1e-10) - self.assertEqual(new.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(new), self.xp) def test_redundancy(self): for key in self.prior_set_from_dict.keys(): diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index c3fa1e865..c24ae1118 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -1,3 +1,4 @@ +import array_api_compat as aac import bilby import unittest import numpy as np @@ -268,27 +269,27 @@ def test_minimum_rescaling(self): # and so the rescale function doesn't quite return the lower bound continue elif bilby.core.prior.JointPrior in prior.__class__.__mro__: - minimum_sample = prior.rescale(self.xp.array(0)) + minimum_sample = prior.rescale(self.xp.asarray(0)) if prior.dist.filled_rescale(): self.assertAlmostEqual(minimum_sample[0], prior.minimum) self.assertAlmostEqual(minimum_sample[1], prior.minimum) else: - minimum_sample = prior.rescale(self.xp.array(0)) + minimum_sample = prior.rescale(self.xp.asarray(0)) self.assertAlmostEqual(minimum_sample, prior.minimum) def test_maximum_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: if bilby.core.prior.JointPrior in prior.__class__.__mro__: - maximum_sample = prior.rescale(self.xp.array(0)) + maximum_sample = prior.rescale(self.xp.asarray(0)) if prior.dist.filled_rescale(): self.assertAlmostEqual(maximum_sample[0], prior.maximum) self.assertAlmostEqual(maximum_sample[1], prior.maximum) elif isinstance(prior, bilby.gw.prior.AlignedSpin): - maximum_sample = prior.rescale(self.xp.array(1)) + maximum_sample = 
prior.rescale(self.xp.asarray(1)) self.assertGreater(maximum_sample, 0.997) else: - maximum_sample = prior.rescale(self.xp.array(1)) + maximum_sample = prior.rescale(self.xp.asarray(1)) self.assertAlmostEqual(maximum_sample, prior.maximum) def test_many_sample_rescaling(self): @@ -297,14 +298,14 @@ def test_many_sample_rescaling(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - many_samples = prior.rescale(self.xp.array(np.random.uniform(0, 1, 1000))) + many_samples = prior.rescale(self.xp.asarray(np.random.uniform(0, 1, 1000))) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_rescale(): continue self.assertTrue( all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)) ) - self.assertEqual(many_samples.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(many_samples), self.xp) def test_least_recently_sampled(self): for prior in self.priors: @@ -312,7 +313,7 @@ def test_least_recently_sampled(self): self.assertEqual( least_recently_sampled_expected, prior.least_recently_sampled ) - self.assertEqual(least_recently_sampled_expected.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(least_recently_sampled_expected), self.xp) def test_sampling_single(self): """Test that sampling from the prior always returns values within its domain.""" @@ -324,7 +325,7 @@ def test_sampling_single(self): self.assertTrue( (single_sample >= prior.minimum) & (single_sample <= prior.maximum) ) - self.assertEqual(single_sample.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(single_sample), self.xp) def test_sampling_many(self): """Test that sampling from the prior always returns values within its domain.""" @@ -337,7 +338,7 @@ def test_sampling_many(self): (all(many_samples >= prior.minimum)) & (all(many_samples <= prior.maximum)) ) - self.assertEqual(many_samples.__array_namespace__(), self.xp) + 
self.assertEqual(aac.get_namespace(many_samples), self.xp) def test_probability_above_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" @@ -372,7 +373,7 @@ def test_least_recently_sampled_2(self): for prior in self.priors: lrs = prior.sample(xp=self.xp) self.assertEqual(lrs, prior.least_recently_sampled) - self.assertEqual(lrs.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(lrs), self.xp) def test_prob_and_ln_prob(self): for prior in self.priors: @@ -386,8 +387,8 @@ def test_prob_and_ln_prob(self): self.assertAlmostEqual( self.xp.log(prob), lnprob, 6 ) - self.assertEqual(lnprob.__array_namespace__(), self.xp) - self.assertEqual(prob.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(lnprob), self.xp) + self.assertEqual(aac.get_namespace(prob), self.xp) def test_many_prob_and_many_ln_prob(self): for prior in self.priors: @@ -398,8 +399,8 @@ def test_many_prob_and_many_ln_prob(self): for sample, logp, p in zip(samples, ln_probs, probs): self.assertAlmostEqual(prior.ln_prob(sample), logp) self.assertAlmostEqual(prior.prob(sample), p) - self.assertEqual(ln_probs.__array_namespace__(), self.xp) - self.assertEqual(probs.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(ln_probs), self.xp) + self.assertEqual(aac.get_namespace(probs), self.xp) def test_cdf_is_inverse_of_rescaling(self): domain = self.xp.linspace(0, 1, 100) @@ -421,11 +422,11 @@ def test_cdf_is_inverse_of_rescaling(self): self.assertTrue(np.array_equal(rescaled, rescaled_2)) max_difference = max(np.abs(cdf_vals - cdf_vals_2)) for arr in [rescaled, rescaled_2, cdf_vals, cdf_vals_2]: - self.assertEqual(arr.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(arr), self.xp) else: rescaled = prior.rescale(domain) max_difference = max(np.abs(domain - prior.cdf(rescaled))) - self.assertEqual(rescaled.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(rescaled), 
self.xp) self.assertLess(max_difference, threshold) def test_cdf_one_above_domain(self): @@ -669,13 +670,13 @@ def test_normalized(self): domain = np.linspace(prior.minimum, prior.maximum, 10000) elif isinstance(prior, bilby.core.prior.WeightedDiscreteValues): domain = prior.values - self.assertTrue(np.sum(prior.prob(self.xp.array(domain))) == 1) + self.assertTrue(np.sum(prior.prob(self.xp.asarray(domain))) == 1) continue else: domain = np.linspace(prior.minimum, prior.maximum, 1000) - probs = prior.prob(self.xp.array(domain)) + probs = prior.prob(self.xp.asarray(domain)) self.assertAlmostEqual(trapezoid(np.array(probs), domain), 1, 3) - self.assertEqual(probs.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(probs), self.xp) def test_accuracy(self): """Test that each of the priors' functions is calculated accurately, as compared to scipy's calculations""" @@ -771,14 +772,14 @@ def test_accuracy(self): ) if isinstance(prior, (testTuple)): print(prior) - np.testing.assert_almost_equal(prior.prob(self.xp.array(domain)), scipy_prob) - np.testing.assert_almost_equal(prior.ln_prob(self.xp.array(domain)), scipy_lnprob) - np.testing.assert_almost_equal(prior.cdf(self.xp.array(domain)), scipy_cdf) + np.testing.assert_almost_equal(prior.prob(self.xp.asarray(domain)), scipy_prob) + np.testing.assert_almost_equal(prior.ln_prob(self.xp.asarray(domain)), scipy_lnprob) + np.testing.assert_almost_equal(prior.cdf(self.xp.asarray(domain)), scipy_cdf) if isinstance(prior, bilby.core.prior.StudentT) and "jax" in str(self.xp): # JAX implementation of StudentT prior rescale is not accurate enough continue np.testing.assert_almost_equal( - prior.rescale(self.xp.array(rescale_domain)), scipy_rescale + prior.rescale(self.xp.asarray(rescale_domain)), scipy_rescale ) def test_unit_setting(self): diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index cd066ef05..c8a07aae8 100644 --- a/test/core/prior/slabspike_test.py +++ 
b/test/core/prior/slabspike_test.py @@ -70,7 +70,7 @@ class TestSlabSpikeClasses(unittest.TestCase): def setUp(self): self.minimum = 0.4 self.maximum = 2.4 - self.spike_loc = self.xp.array(1.5) + self.spike_loc = self.xp.asarray(1.5) self.spike_height = 0.3 self.slabs = [ @@ -114,7 +114,7 @@ def test_prob_on_slab(self): expected = slab.prob(test_nodes) * slab_spike.slab_fraction actual = slab_spike.prob(test_nodes) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_prob_on_spike(self): for slab_spike in self.slab_spikes: @@ -125,13 +125,13 @@ def test_ln_prob_on_slab(self): expected = slab.ln_prob(test_nodes) + np.log(slab_spike.slab_fraction) actual = slab_spike.ln_prob(test_nodes) self.assertTrue(np.array_equal(expected, actual)) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_ln_prob_on_spike(self): for slab_spike in self.slab_spikes: expected = slab_spike.ln_prob(self.spike_loc) self.assertEqual(np.inf, expected) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(expected), self.xp) def test_inverse_cdf_below_spike_with_spike_at_minimum(self): for slab in self.slabs: @@ -154,14 +154,14 @@ def test_cdf_below_spike(self): expected = slab.cdf(test_nodes) * slab_spike.slab_fraction actual = slab_spike.cdf(test_nodes) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) - self.assertEqual(actual.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_cdf_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): expected = slab.cdf(self.spike_loc) * slab_spike.slab_fraction actual = slab_spike.cdf(self.xp.asarray(self.spike_loc)) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) - self.assertEqual(actual.__array_namespace__(), self.xp) + 
self.assertEqual(aac.get_namespace(actual), self.xp) def test_cdf_above_spike(self): for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): @@ -169,7 +169,7 @@ def test_cdf_above_spike(self): expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + self.spike_height actual = slab_spike.cdf(test_nodes) np.testing.assert_allclose(expected, actual, rtol=1e-12) - self.assertEqual(actual.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_cdf_at_minimum(self): for slab_spike in self.slab_spikes: @@ -190,7 +190,7 @@ def test_rescale_no_spike(self): expected = slab.rescale(vals) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_rescale_below_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): @@ -198,7 +198,7 @@ def test_rescale_below_spike(self): expected = slab.rescale(vals / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_rescale_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): @@ -209,7 +209,7 @@ def test_rescale_at_spike(self): expected = self.xp.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_rescale_above_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): @@ -218,4 +218,4 @@ def test_rescale_above_spike(self): (vals - self.spike_height) / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, 
actual, rtol=1e-5)) - self.assertEqual(expected.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual), self.xp) diff --git a/test/core/result_test.py b/test/core/result_test.py index 64fc9c14f..d36f03190 100644 --- a/test/core/result_test.py +++ b/test/core/result_test.py @@ -30,7 +30,7 @@ def test_list_encoding(self): self.assertTrue(np.all(data["x"] == decoded["x"])) def test_array_encoding(self): - data = dict(x=self.xp.array([1, 2, 3.4])) + data = dict(x=self.xp.asarray([1, 2, 3.4])) encoded = json.dumps(data, cls=self.encoder) decoded = json.loads(encoded, object_hook=self.decoder) self.assertEqual(data.keys(), decoded.keys()) diff --git a/test/core/series_test.py b/test/core/series_test.py index 7b85c2bc5..aec2ff42f 100644 --- a/test/core/series_test.py +++ b/test/core/series_test.py @@ -11,9 +11,9 @@ @pytest.mark.usefixtures("xp_class") class TestCoupledTimeAndFrequencySeries(unittest.TestCase): def setUp(self): - self.duration = self.xp.array(2.0) - self.sampling_frequency = self.xp.array(4096.0) - self.start_time = self.xp.array(-1.0) + self.duration = self.xp.asarray(2.0) + self.sampling_frequency = self.xp.asarray(4096.0) + self.start_time = self.xp.asarray(-1.0) self.series = CoupledTimeAndFrequencySeries( duration=self.duration, sampling_frequency=self.sampling_frequency, @@ -67,8 +67,8 @@ def test_time_array_from_init(self): self.assertTrue(np.array_equal(expected, self.series.time_array)) def test_frequency_array_setter(self): - new_sampling_frequency = self.xp.array(100.0) - new_duration = self.xp.array(3.0) + new_sampling_frequency = self.xp.asarray(100.0) + new_duration = self.xp.asarray(3.0) new_frequency_array = create_frequency_series( sampling_frequency=new_sampling_frequency, duration=new_duration ) @@ -83,9 +83,9 @@ def test_frequency_array_setter(self): self.assertAlmostEqual(self.start_time, self.series.start_time) def test_time_array_setter(self): - new_sampling_frequency = self.xp.array(100.0) - new_duration = 
self.xp.array(3.0) - new_start_time = self.xp.array(4.0) + new_sampling_frequency = self.xp.asarray(100.0) + new_duration = self.xp.asarray(3.0) + new_start_time = self.xp.asarray(4.0) new_time_array = create_time_series( sampling_frequency=new_sampling_frequency, duration=new_duration, @@ -101,24 +101,24 @@ def test_time_array_setter(self): def test_time_array_without_sampling_frequency(self): self.series.sampling_frequency = None - self.series.duration = self.xp.array(4) + self.series.duration = self.xp.asarray(4) with self.assertRaises(ValueError): _ = self.series.time_array def test_time_array_without_duration(self): - self.series.sampling_frequency = self.xp.array(4096) + self.series.sampling_frequency = self.xp.asarray(4096) self.series.duration = None with self.assertRaises(ValueError): _ = self.series.time_array def test_frequency_array_without_sampling_frequency(self): self.series.sampling_frequency = None - self.series.duration = self.xp.array(4) + self.series.duration = self.xp.asarray(4) with self.assertRaises(ValueError): _ = self.series.frequency_array def test_frequency_array_without_duration(self): - self.series.sampling_frequency = self.xp.array(4096) + self.series.sampling_frequency = self.xp.asarray(4096) self.series.duration = None with self.assertRaises(ValueError): _ = self.series.frequency_array diff --git a/test/core/utils_test.py b/test/core/utils_test.py index fdd4afeef..bd1869a6a 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -1,6 +1,7 @@ import unittest import os +import array_api_compat as aac import array_api_extra as xpx import dill import numpy as np @@ -54,16 +55,16 @@ def test_gravitational_constant(self): @pytest.mark.usefixtures("xp_class") class TestFFT(unittest.TestCase): def setUp(self): - self.sampling_frequency = self.xp.array(10) + self.sampling_frequency = self.xp.asarray(10) def tearDown(self): del self.sampling_frequency def test_nfft_sine_function(self): xp = self.xp - injected_frequency = 
xp.array(2.7324) - duration = xp.array(100) - times = utils.create_time_series(xp.array(self.sampling_frequency), duration) + injected_frequency = xp.asarray(2.7324) + duration = xp.asarray(100) + times = utils.create_time_series(xp.asarray(self.sampling_frequency), duration) time_domain_strain = xp.sin(2 * np.pi * times * injected_frequency + 0.4) @@ -75,7 +76,7 @@ def test_nfft_sine_function(self): def test_nfft_infft(self): xp = self.xp - time_domain_strain = xp.array(np.random.normal(0, 1, 10)) + time_domain_strain = xp.asarray(np.random.normal(0, 1, 10)) frequency_domain_strain, _ = bilby.core.utils.nfft( time_domain_strain, self.sampling_frequency ) @@ -128,9 +129,9 @@ def test_self_handling_method_as_function(self): @pytest.mark.usefixtures("xp_class") class TestTimeAndFrequencyArrays(unittest.TestCase): def setUp(self): - self.start_time = self.xp.array(1.3) - self.sampling_frequency = self.xp.array(5) - self.duration = self.xp.array(1.6) + self.start_time = self.xp.asarray(1.3) + self.sampling_frequency = self.xp.asarray(5) + self.duration = self.xp.asarray(1.6) self.frequency_array = utils.create_frequency_series( sampling_frequency=self.sampling_frequency, duration=self.duration ) @@ -148,13 +149,13 @@ def tearDown(self): del self.time_array def test_create_time_array(self): - expected_time_array = self.xp.array([1.3, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7]) + expected_time_array = self.xp.asarray([1.3, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7]) time_array = utils.create_time_series( sampling_frequency=self.sampling_frequency, duration=self.duration, starting_time=self.start_time, ) - self.assertEqual(time_array.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(time_array), self.xp) self.assertTrue(np.allclose(expected_time_array, time_array)) def test_create_frequency_array(self): @@ -243,9 +244,9 @@ def test_consistency_frequency_array_to_frequency_array(self): def test_illegal_sampling_frequency_and_duration(self): with 
self.assertRaises(utils.IllegalDurationAndSamplingFrequencyException): _ = utils.create_time_series( - sampling_frequency=self.xp.array(7.7), - duration=self.xp.array(1.3), - starting_time=self.xp.array(0), + sampling_frequency=self.xp.asarray(7.7), + duration=self.xp.asarray(1.3), + starting_time=self.xp.asarray(0), ) @@ -253,28 +254,28 @@ def test_illegal_sampling_frequency_and_duration(self): @pytest.mark.usefixtures("xp_class") class TestReflect(unittest.TestCase): def test_in_range(self): - xprime = self.xp.array([0.1, 0.5, 0.9]) - x = self.xp.array([0.1, 0.5, 0.9]) + xprime = self.xp.asarray([0.1, 0.5, 0.9]) + x = self.xp.asarray([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_one_to_two(self): - xprime = self.xp.array([1.1, 1.5, 1.9]) - x = self.xp.array([0.9, 0.5, 0.1]) + xprime = self.xp.asarray([1.1, 1.5, 1.9]) + x = self.xp.asarray([0.9, 0.5, 0.1]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_two_to_three(self): - xprime = self.xp.array([2.1, 2.5, 2.9]) - x = self.xp.array([0.1, 0.5, 0.9]) + xprime = self.xp.asarray([2.1, 2.5, 2.9]) + x = self.xp.asarray([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_minus_one_to_zero(self): - xprime = self.xp.array([-0.9, -0.5, -0.1]) - x = self.xp.array([0.9, 0.5, 0.1]) + xprime = self.xp.asarray([-0.9, -0.5, -0.1]) + x = self.xp.asarray([0.9, 0.5, 0.1]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) def test_in_minus_two_to_minus_one(self): - xprime = self.xp.array([-1.9, -1.5, -1.1]) - x = self.xp.array([0.1, 0.5, 0.9]) + xprime = self.xp.asarray([-1.9, -1.5, -1.1]) + x = self.xp.asarray([0.1, 0.5, 0.9]) self.assertTrue(np.testing.assert_allclose(utils.reflect(xprime), x) is None) @@ -359,23 +360,23 @@ def test_returns_none_for_floats_outside_range(self): self.assertIsNone(self.interpolant(-0.5, 0.5)) def 
test_returns_float_for_float_and_array(self): - input_array = self.xp.array(np.random.random(10)) - self.assertEqual(self.interpolant(input_array, 0.5).__array_namespace__(), self.xp) - self.assertEqual( - self.interpolant(input_array, input_array).__array_namespace__(), self.xp + input_array = self.xp.asarray(np.random.random(10)) + self.assertEqual(aac.get_namespace(self.interpolant(input_array, 0.5)), self.xp) + self.assertEqual(aac.get_namespace( + self.interpolant(input_array, input_array)), self.xp ) - self.assertEqual(self.interpolant(0.5, input_array).__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(self.interpolant(0.5, input_array)), self.xp) def test_raises_for_mismatched_arrays(self): with self.assertRaises(ValueError): self.interpolant( - self.xp.array(np.random.random(10)), - self.xp.array(np.random.random(20)), + self.xp.asarray(np.random.random(10)), + self.xp.asarray(np.random.random(20)), ) def test_returns_fill_in_correct_place(self): - x_data = self.xp.array(np.random.random(10)) - y_data = self.xp.array(np.random.random(10)) + x_data = self.xp.asarray(np.random.random(10)) + y_data = self.xp.asarray(np.random.random(10)) x_data = xpx.at(x_data, 3).set(-1) self.assertTrue(self.xp.isnan(self.interpolant(x_data, y_data)[3])) @@ -394,7 +395,7 @@ def setUp(self): self.lnfunc2 = self.xp.log(self.x ** 2) self.func2int = (self.x[-1] ** 3 - self.x[0] ** 3) / 3 - self.irregularx = self.xp.array( + self.irregularx = self.xp.asarray( [ self.x[0], self.x[12], diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index 4961c11e0..d0ce869b7 100644 --- a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -1,6 +1,7 @@ import unittest from copy import deepcopy +import array_api_compat as aac import numpy as np import pandas as pd import pytest @@ -13,17 +14,17 @@ @pytest.mark.usefixtures("xp_class") class TestBasicConversions(unittest.TestCase): def setUp(self): - self.mass_1 = self.xp.array(1.4) - self.mass_2 = 
self.xp.array(1.3) - self.mass_ratio = self.xp.array(13 / 14) - self.total_mass = self.xp.array(2.7) + self.mass_1 = self.xp.asarray(1.4) + self.mass_2 = self.xp.asarray(1.3) + self.mass_ratio = self.xp.asarray(13 / 14) + self.total_mass = self.xp.asarray(2.7) self.chirp_mass = (self.mass_1 * self.mass_2) ** 0.6 / self.total_mass ** 0.2 self.symmetric_mass_ratio = (self.mass_1 * self.mass_2) / self.total_mass ** 2 - self.cos_angle = self.xp.array(-1.0) + self.cos_angle = self.xp.asarray(-1.0) self.angle = self.xp.pi - self.lambda_1 = self.xp.array(300.0) - self.lambda_2 = self.xp.array(300.0 * (14 / 13) ** 5) - self.lambda_tilde = self.xp.array( + self.lambda_1 = self.xp.asarray(300.0) + self.lambda_2 = self.xp.asarray(300.0 * (14 / 13) ** 5) + self.lambda_tilde = self.xp.asarray( 8 / 13 * ( @@ -42,7 +43,7 @@ def setUp(self): * (self.lambda_1 - self.lambda_2) ) ) - self.delta_lambda_tilde = self.xp.array( + self.delta_lambda_tilde = self.xp.asarray( 1 / 2 * ( @@ -78,36 +79,36 @@ def test_total_mass_and_mass_ratio_to_component_masses(self): self.assertTrue( all([abs(mass_1 - self.mass_1) < 1e-5, abs(mass_2 - self.mass_2) < 1e-5]) ) - self.assertEqual(mass_1.__array_namespace__(), self.xp) - self.assertEqual(mass_2.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(mass_1), self.xp) + self.assertEqual(aac.get_namespace(mass_2), self.xp) def test_chirp_mass_and_primary_mass_to_mass_ratio(self): mass_ratio = conversion.chirp_mass_and_primary_mass_to_mass_ratio( self.chirp_mass, self.mass_1 ) self.assertAlmostEqual(self.mass_ratio, mass_ratio) - self.assertEqual(mass_ratio.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_symmetric_mass_ratio_to_mass_ratio(self): mass_ratio = conversion.symmetric_mass_ratio_to_mass_ratio( self.symmetric_mass_ratio ) self.assertAlmostEqual(self.mass_ratio, mass_ratio) - self.assertEqual(mass_ratio.__array_namespace__(), self.xp) + 
self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_chirp_mass_and_total_mass_to_symmetric_mass_ratio(self): symmetric_mass_ratio = conversion.chirp_mass_and_total_mass_to_symmetric_mass_ratio( self.chirp_mass, self.total_mass ) self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) - self.assertEqual(symmetric_mass_ratio.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(symmetric_mass_ratio), self.xp) def test_chirp_mass_and_mass_ratio_to_total_mass(self): total_mass = conversion.chirp_mass_and_mass_ratio_to_total_mass( self.chirp_mass, self.mass_ratio ) self.assertAlmostEqual(self.total_mass, total_mass) - self.assertEqual(total_mass.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(total_mass), self.xp) def test_chirp_mass_and_mass_ratio_to_component_masses(self): mass_1, mass_2 = \ @@ -115,37 +116,37 @@ def test_chirp_mass_and_mass_ratio_to_component_masses(self): self.chirp_mass, self.mass_ratio) self.assertAlmostEqual(self.mass_1, mass_1) self.assertAlmostEqual(self.mass_2, mass_2) - self.assertEqual(mass_1.__array_namespace__(), self.xp) - self.assertEqual(mass_2.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(mass_1), self.xp) + self.assertEqual(aac.get_namespace(mass_2), self.xp) def test_component_masses_to_chirp_mass(self): chirp_mass = conversion.component_masses_to_chirp_mass(self.mass_1, self.mass_2) self.assertAlmostEqual(self.chirp_mass, chirp_mass) - self.assertEqual(chirp_mass.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(chirp_mass), self.xp) def test_component_masses_to_total_mass(self): total_mass = conversion.component_masses_to_total_mass(self.mass_1, self.mass_2) self.assertAlmostEqual(self.total_mass, total_mass) - self.assertEqual(total_mass.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(total_mass), self.xp) def test_component_masses_to_symmetric_mass_ratio(self): symmetric_mass_ratio = 
conversion.component_masses_to_symmetric_mass_ratio( self.mass_1, self.mass_2 ) self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) - self.assertEqual(symmetric_mass_ratio.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(symmetric_mass_ratio), self.xp) def test_component_masses_to_mass_ratio(self): mass_ratio = conversion.component_masses_to_mass_ratio(self.mass_1, self.mass_2) self.assertAlmostEqual(self.mass_ratio, mass_ratio) - self.assertEqual(mass_ratio.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_mass_1_and_chirp_mass_to_mass_ratio(self): mass_ratio = conversion.mass_1_and_chirp_mass_to_mass_ratio( self.mass_1, self.chirp_mass ) self.assertAlmostEqual(self.mass_ratio, mass_ratio) - self.assertEqual(mass_ratio.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_lambda_tilde_to_lambda_1_lambda_2(self): lambda_1, lambda_2 = conversion.lambda_tilde_to_lambda_1_lambda_2( @@ -159,8 +160,8 @@ def test_lambda_tilde_to_lambda_1_lambda_2(self): ] ) ) - self.assertEqual(lambda_1.__array_namespace__(), self.xp) - self.assertEqual(lambda_2.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(lambda_1), self.xp) + self.assertEqual(aac.get_namespace(lambda_2), self.xp) def test_lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2(self): ( @@ -177,22 +178,22 @@ def test_lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2(self): ] ) ) - self.assertEqual(lambda_1.__array_namespace__(), self.xp) - self.assertEqual(lambda_2.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(lambda_1), self.xp) + self.assertEqual(aac.get_namespace(lambda_2), self.xp) def test_lambda_1_lambda_2_to_lambda_tilde(self): lambda_tilde = conversion.lambda_1_lambda_2_to_lambda_tilde( self.lambda_1, self.lambda_2, self.mass_1, self.mass_2 ) self.assertTrue((self.lambda_tilde - lambda_tilde) < 1e-5) - 
self.assertEqual(lambda_tilde.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(lambda_tilde), self.xp) def test_lambda_1_lambda_2_to_delta_lambda_tilde(self): delta_lambda_tilde = conversion.lambda_1_lambda_2_to_delta_lambda_tilde( self.lambda_1, self.lambda_2, self.mass_1, self.mass_2 ) self.assertTrue((self.delta_lambda_tilde - delta_lambda_tilde) < 1e-5) - self.assertEqual(delta_lambda_tilde.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(delta_lambda_tilde), self.xp) def test_identity_conversion(self): original_samples = dict( @@ -630,16 +631,16 @@ def test_comoving_luminosity_with_cosmology(self): @pytest.mark.usefixtures("xp_class") class TestGenerateMassParameters(unittest.TestCase): def setUp(self): - self.expected_values = {'mass_1': self.xp.array(2.0), - 'mass_2': self.xp.array(1.0), - 'chirp_mass': self.xp.array(1.2167286837864113), - 'total_mass': self.xp.array(3.0), - 'mass_1_source': self.xp.array(4.0), - 'mass_2_source': self.xp.array(2.0), - 'chirp_mass_source': self.xp.array(2.433457367572823), - 'total_mass_source': self.xp.array(6), - 'symmetric_mass_ratio': self.xp.array(0.2222222222222222), - 'mass_ratio': self.xp.array(0.5)} + self.expected_values = {'mass_1': self.xp.asarray(2.0), + 'mass_2': self.xp.asarray(1.0), + 'chirp_mass': self.xp.asarray(1.2167286837864113), + 'total_mass': self.xp.asarray(3.0), + 'mass_1_source': self.xp.asarray(4.0), + 'mass_2_source': self.xp.asarray(2.0), + 'chirp_mass_source': self.xp.asarray(2.433457367572823), + 'total_mass_source': self.xp.asarray(6), + 'symmetric_mass_ratio': self.xp.asarray(0.2222222222222222), + 'mass_ratio': self.xp.asarray(0.5)} def helper_generation_from_keys(self, keys, expected_values, source=False): # Explicitly test the helper generate_component_masses @@ -688,7 +689,7 @@ def helper_generation_from_keys(self, keys, expected_values, source=False): for key in local_all_mass_parameters.keys(): self.assertAlmostEqual(expected_values[key], 
local_all_mass_parameters[key]) self.assertEqual( - local_all_mass_parameters[key].__array_namespace__(), + aac.get_namespace(local_all_mass_parameters[key]), self.xp, ) @@ -767,42 +768,42 @@ class TestEquationOfStateConversions(unittest.TestCase): ''' def setUp(self): - self.mass_1_source_spectral = self.xp.array([ + self.mass_1_source_spectral = self.xp.asarray([ 4.922542724434885, 4.350626907771598, 4.206155335439082, 1.7822696459661311, 1.3091740103047926 ]) - self.mass_2_source_spectral = self.xp.array([ + self.mass_2_source_spectral = self.xp.asarray([ 3.459974694590303, 1.2276461777181447, 3.7287707089639976, 0.3724016563531846, 1.055042934805801 ]) - self.spectral_pca_gamma_0 = self.xp.array([ + self.spectral_pca_gamma_0 = self.xp.asarray([ 0.7074873121348357, 0.05855931126849878, 0.7795329261793462, 1.467907561566463, 2.9066488405635624 ]) - self.spectral_pca_gamma_1 = self.xp.array([ + self.spectral_pca_gamma_1 = self.xp.asarray([ -0.29807111670823816, 2.027708558522935, -1.4415775226512115, -0.7104870098896858, -0.4913817181089619 ]) - self.spectral_pca_gamma_2 = self.xp.array([ + self.spectral_pca_gamma_2 = self.xp.asarray([ 0.25625095371021156, -0.19574096643220049, -0.2710238103460012, 0.22815820981582358, -0.1543413205016374 ]) - self.spectral_pca_gamma_3 = self.xp.array([ + self.spectral_pca_gamma_3 = self.xp.asarray([ -0.04030365100175101, 0.05698030777919032, -0.045595911403040264, @@ -914,7 +915,7 @@ def test_spectral_pca_to_spectral(self): self.assertAlmostEqual(spectral_gamma_2, self.spectral_gamma_2[i], places=5) self.assertAlmostEqual(spectral_gamma_3, self.spectral_gamma_3[i], places=5) for val in [spectral_gamma_0, spectral_gamma_1, spectral_gamma_2, spectral_gamma_3]: - self.assertEqual(val.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(val), self.xp) def test_spectral_params_to_lambda_1_lambda_2(self): ''' diff --git a/test/gw/detector/geometry_test.py b/test/gw/detector/geometry_test.py index 4906f00cc..b18eb218d 
100644 --- a/test/gw/detector/geometry_test.py +++ b/test/gw/detector/geometry_test.py @@ -1,6 +1,7 @@ import unittest from unittest import mock +import array_api_compat as aac import numpy as np import pytest @@ -44,35 +45,35 @@ def tearDown(self): def test_length_setting(self): self.assertEqual(self.geometry.length, self.length) - self.assertEqual(self.geometry.length.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(length), self.xp) def test_latitude_setting(self): self.assertEqual(self.geometry.latitude, self.latitude) - self.assertEqual(self.geometry.latitude.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(latitude), self.xp) def test_longitude_setting(self): self.assertEqual(self.geometry.longitude, self.longitude) - self.assertEqual(self.geometry.longitude.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(longitude), self.xp) def test_elevation_setting(self): self.assertEqual(self.geometry.elevation, self.elevation) - self.assertEqual(self.geometry.elevation.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(elevation), self.xp) def test_xarm_azi_setting(self): self.assertEqual(self.geometry.xarm_azimuth, self.xarm_azimuth) - self.assertEqual(self.geometry.xarm_azimuth.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(xarm_azimuth), self.xp) def test_yarm_azi_setting(self): self.assertEqual(self.geometry.yarm_azimuth, self.yarm_azimuth) - self.assertEqual(self.geometry.yarm_azimuth.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(yarm_azimuth), self.xp) def test_xarm_tilt_setting(self): self.assertEqual(self.geometry.xarm_tilt, self.xarm_tilt) - self.assertEqual(self.geometry.xarm_tilt.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(xarm_tilt), self.xp) def test_yarm_tilt_setting(self): 
self.assertEqual(self.geometry.yarm_tilt, self.yarm_tilt) - self.assertEqual(self.geometry.yarm_tilt.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(yarm_tilt), self.xp) def test_vertex_without_update(self): _ = self.geometry.vertex @@ -154,37 +155,37 @@ def test_detector_tensor_with_x_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.xarm_azimuth += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) def test_detector_tensor_with_y_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.yarm_azimuth += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) def test_detector_tensor_with_x_tilt_update(self): original = self.geometry.detector_tensor self.geometry.xarm_tilt += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) def test_detector_tensor_with_y_tilt_update(self): original = self.geometry.detector_tensor self.geometry.yarm_tilt += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) def test_detector_tensor_with_longitude_update(self): original = self.geometry.detector_tensor self.geometry.longitude += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - 
self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) def test_detector_tensor_with_latitude_update(self): original = self.geometry.detector_tensor self.geometry.latitude += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - self.assertEqual(self.geometry.detector_tensor.__array_namespace__(), self.xp) + self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) def test_unit_vector_along_arm_default(self): with self.assertRaises(ValueError): @@ -198,7 +199,7 @@ def test_unit_vector_along_arm_x(self): self.geometry.set_array_backend(self.xp) arm = self.geometry.unit_vector_along_arm("x") self.assertTrue(np.allclose(arm, np.array([0, 1, 0]))) - self.assertEqual(arm.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(arm), self.xp) def test_unit_vector_along_arm_y(self): self.geometry.longitude = 0 @@ -208,7 +209,7 @@ def test_unit_vector_along_arm_y(self): self.geometry.set_array_backend(self.xp) arm = self.geometry.unit_vector_along_arm("y") self.assertTrue(np.allclose(arm, np.array([0, 0, 1]))) - self.assertEqual(arm.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(arm), self.xp) def test_repr(self): expected = ( diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 6acf60757..35f4de99b 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -7,6 +7,7 @@ import pytest from copy import deepcopy +import array_api_compat as aac import h5py import numpy as np import bilby @@ -27,7 +28,7 @@ def __getattr__(self, name): def convert_nested_dict(self, data): if is_array_api_obj(data): - return self.xp.array(data) + return self.xp.asarray(data) elif isinstance(data, dict): return {key: self.convert_nested_dict(value) for key, value in data.items()} else: @@ -70,8 +71,8 @@ def setUp(self): ) self.interferometers.set_array_backend(self.xp) base_wfg = 
bilby.gw.waveform_generator.GWSignalWaveformGenerator( - duration=self.xp.array(4.0), - sampling_frequency=self.xp.array(2048.0), + duration=self.xp.asarray(4.0), + sampling_frequency=self.xp.asarray(2048.0), waveform_arguments=dict(waveform_approximant="IMRPhenomPv2"), ) self.waveform_generator = BackendWaveformGenerator(base_wfg, self.xp) @@ -94,13 +95,13 @@ def test_noise_log_likelihood(self): self.assertAlmostEqual( -4014.1787704539474, nll, 3 ) - self.assertEqual(nll.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(nll), self.xp) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" logl = self.likelihood.log_likelihood(self.parameters) self.assertAlmostEqual(logl, -4032.4397343470005, 3) - self.assertEqual(logl.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(logl), self.xp) def test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" @@ -110,7 +111,7 @@ def test_log_likelihood_ratio(self): llr, 3, ) - self.assertEqual(llr.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(llr), self.xp) def test_likelihood_zero_when_waveform_is_none(self): """Test log likelihood returns np.nan_to_num(-np.inf) when the @@ -130,8 +131,8 @@ def test_repr(self): class TestGWTransient(unittest.TestCase): def setUp(self): bilby.core.utils.random.seed(500) - self.duration = self.xp.array(4.0) - self.sampling_frequency = self.xp.array(2048.0) + self.duration = self.xp.asarray(4.0) + self.sampling_frequency = self.xp.asarray(2048.0) self.parameters = dict( mass_1=31.0, mass_2=29.0, @@ -186,13 +187,13 @@ def test_noise_log_likelihood(self): self.assertAlmostEqual( -4014.1787704539474, nll, 3 ) - self.assertEqual(nll.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(nll), self.xp) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" logl = self.likelihood.log_likelihood(self.parameters) self.assertAlmostEqual(logl, 
-4032.4397343470005, 3) - self.assertEqual(logl.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(logl), self.xp) def test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" @@ -202,7 +203,7 @@ def test_log_likelihood_ratio(self): llr, 3, ) - self.assertEqual(llr.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(llr), self.xp) def test_likelihood_zero_when_waveform_is_none(self): """Test log likelihood returns np.nan_to_num(-np.inf) when the @@ -282,16 +283,16 @@ def test_reference_frame_agrees_with_default(self): ) parameters = self.parameters.copy() del parameters["ra"], parameters["dec"] - parameters["zenith"] = self.xp.array(1.0) - parameters["azimuth"] = self.xp.array(1.0) + parameters["zenith"] = self.xp.asarray(1.0) + parameters["azimuth"] = self.xp.asarray(1.0) parameters["ra"], parameters["dec"] = bilby.gw.utils.zenith_azimuth_to_ra_dec( zenith=parameters["zenith"], azimuth=parameters["azimuth"], geocent_time=parameters["geocent_time"], ifos=new_likelihood.reference_frame, ) - self.assertEqual(parameters["ra"].__array_namespace__(), self.xp) - self.assertEqual(parameters["dec"].__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(parameters["ra"]), self.xp) + self.assertEqual(aac.get_namespace(parameters["dec"]), self.xp) self.assertEqual( new_likelihood.log_likelihood_ratio(parameters), self.likelihood.log_likelihood_ratio(parameters) @@ -335,8 +336,8 @@ def test_time_reference_agrees_with_default(self): @pytest.mark.usefixtures("xp_class") class TestROQLikelihood(unittest.TestCase): def setUp(self): - self.duration = self.xp.array(4.0) - self.sampling_frequency = self.xp.array(2048.0) + self.duration = self.xp.asarray(4.0) + self.sampling_frequency = self.xp.asarray(2048.0) self.test_parameters = dict( mass_1=36.0, @@ -483,7 +484,7 @@ def test_matches_non_roq(self): ) / self.non_roq.log_likelihood_ratio(self.test_parameters), 1e-3, ) - 
self.assertEqual(roq_llr.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(roq_llr), self.xp) self.non_roq.parameters.update(self.test_parameters) self.roq.parameters.update(self.test_parameters) self.assertLess( @@ -513,7 +514,7 @@ def test_create_roq_weights_with_params(self): roq_llr, self.roq.log_likelihood_ratio(self.test_parameters) ) - self.assertEqual(roq_llr.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(roq_llr), self.xp) roq.parameters.update(self.test_parameters) self.roq.parameters.update(self.test_parameters) self.assertEqual(roq.log_likelihood_ratio(), self.roq.log_likelihood_ratio()) @@ -715,9 +716,9 @@ class TestROQLikelihoodHDF5(unittest.TestCase): _path_to_basis_mb = "/roq_basis/basis_multiband_addcal.hdf5" def setUp(self): - self.minimum_frequency = self.xp.array(20.0) - self.sampling_frequency = self.xp.array(2048.0) - self.duration = self.xp.array(16.0) + self.minimum_frequency = self.xp.asarray(20.0) + self.sampling_frequency = self.xp.asarray(2048.0) + self.duration = self.xp.asarray(16.0) self.reference_frequency = 20.0 self.waveform_approximant = "IMRPhenomD" # The SNRs of injections are 130-160 for roq_scale_factor=1 and 70-80 for roq_scale_factor=2 @@ -759,9 +760,9 @@ def test_fails_with_frequency_duration_mismatch( self.priors["chirp_mass"].maximum = 9 interferometers = bilby.gw.detector.InterferometerList(["H1"]) interferometers.set_strain_data_from_power_spectral_densities( - sampling_frequency=self.xp.array(2 * maximum_frequency), - duration=self.xp.array(duration), - start_time=self.xp.array(self.injection_parameters["geocent_time"] - duration + 1) + sampling_frequency=self.xp.asarray(2 * maximum_frequency), + duration=self.xp.asarray(duration), + start_time=self.xp.asarray(self.injection_parameters["geocent_time"] - duration + 1) ) interferometers.set_array_backend(self.xp) for ifo in interferometers: @@ -793,9 +794,9 @@ def test_fails_with_prior_mismatch(self, basis, chirp_mass_min, 
chirp_mass_max): self.priors["chirp_mass"].maximum = chirp_mass_max interferometers = bilby.gw.detector.InterferometerList(["H1"]) interferometers.set_strain_data_from_power_spectral_densities( - sampling_frequency=self.xp.array(self.sampling_frequency), - duration=self.xp.array(self.duration), - start_time=self.xp.array(self.injection_parameters["geocent_time"] - self.duration + 1) + sampling_frequency=self.xp.asarray(self.sampling_frequency), + duration=self.xp.asarray(self.duration), + start_time=self.xp.asarray(self.injection_parameters["geocent_time"] - self.duration + 1) ) interferometers.set_array_backend(self.xp) for ifo in interferometers: @@ -1011,7 +1012,7 @@ def assertLess_likelihood_errors( llr = likelihood.log_likelihood_ratio(parameters) llr_roq = likelihood_roq.log_likelihood_ratio(parameters) self.assertLess(np.abs(llr - llr_roq), max_llr_error) - self.assertEqual(llr_roq.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(llr_roq), self.xp) likelihood.parameters.update(parameters) likelihood_roq.parameters.update(parameters) llr = likelihood.log_likelihood_ratio() @@ -1348,9 +1349,9 @@ def test_instantiation(self): @pytest.mark.usefixtures("xp_class") class TestMBLikelihood(unittest.TestCase): def setUp(self): - self.duration = self.xp.array(16.0) - self.fmin = self.xp.array(20.0) - self.sampling_frequency = self.xp.array(2048.0) + self.duration = self.xp.asarray(16.0) + self.fmin = self.xp.asarray(20.0) + self.sampling_frequency = self.xp.asarray(2048.0) self.test_parameters = dict( chirp_mass=6.0, mass_ratio=0.5, @@ -1462,7 +1463,7 @@ def test_matches_original_likelihood( abs(likelihood.log_likelihood_ratio(parameters) - llmb), tolerance ) - self.assertEqual(llmb.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(llmb), self.xp) likelihood.parameters.update(parameters) likelihood_mb.parameters.update(parameters) self.assertLess( @@ -1636,7 +1637,7 @@ def test_inout_weights(self, linear_interpolation): # 
likelihood_mb_from_weights.parameters.update(self.test_parameters) llr_from_weights = likelihood_mb_from_weights.log_likelihood_ratio(self.test_parameters) - self.assertEqual(llr_from_weights.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(llr_from_weights), self.xp) self.assertAlmostEqual(llr, llr_from_weights) @@ -1673,7 +1674,7 @@ def test_from_dict_weights(self, linear_interpolation): ) likelihood_mb.parameters.update(self.test_parameters) llr = likelihood_mb.log_likelihood_ratio() - self.assertEqual(llr.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(llr), self.xp) # reset waveform generator to check if likelihood recovered from the weights properly adds banded # frequency points to waveform arguments @@ -1691,7 +1692,7 @@ def test_from_dict_weights(self, linear_interpolation): ) # likelihood_mb_from_weights.parameters.update(self.test_parameters) llr_from_weights = likelihood_mb_from_weights.log_likelihood_ratio(self.test_parameters) - self.assertEqual(llr_from_weights.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(llr_from_weights), self.xp) self.assertAlmostEqual(llr, llr_from_weights) @@ -1744,7 +1745,7 @@ def test_matches_original_likelihood_low_maximum_frequency( abs(likelihood.log_likelihood_ratio(parameters) - llrmb), tolerance ) - self.assertEqual(llrmb.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(llrmb), self.xp) likelihood.parameters.update(parameters) likelihood_mb.parameters.update(parameters) self.assertLess( diff --git a/test/gw/prior_test.py b/test/gw/prior_test.py index c4afa00f4..aec6f01b3 100644 --- a/test/gw/prior_test.py +++ b/test/gw/prior_test.py @@ -5,6 +5,7 @@ import sys import pickle +import array_api_compat as aac import numpy as np from astropy import cosmology from scipy.stats import ks_2samp @@ -582,7 +583,7 @@ def test_non_analytic_form_has_correct_statistics(self): alts = a_prior.sample(100000, xp=self.xp) * z_prior.sample(100000, xp=self.xp) 
self.assertAlmostEqual(np.mean(chis), np.mean(alts), 2) self.assertAlmostEqual(np.std(chis), np.std(alts), 2) - self.assertEqual(chis.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(chis), self.xp) @pytest.mark.array_backend @@ -600,7 +601,7 @@ def test_marginalized_prior_is_uniform(self): samples = priors.sample(100000, xp=self.xp)["a_1"] ks = ks_2samp(samples, np.random.uniform(0, priors["chi_1"].maximum, 100000)) self.assertTrue(ks.pvalue > 0.001) - self.assertEqual(samples.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(samples), self.xp) if __name__ == "__main__": diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index d082a01a1..80b8fe60d 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -3,6 +3,7 @@ from shutil import rmtree from importlib.metadata import version +import array_api_compat as aac import numpy as np import lal import lalsimulation as lalsim @@ -29,32 +30,32 @@ def tearDown(self): pass def test_asd_from_freq_series(self): - freq_data = self.xp.array([1, 2, 3]) + freq_data = self.xp.asarray([1, 2, 3]) df = 0.1 asd = gwutils.asd_from_freq_series(freq_data, df) self.assertTrue(np.all(asd == freq_data * 2 * df ** 0.5)) - self.assertEqual(asd.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(asd), self.xp) def test_psd_from_freq_series(self): - freq_data = self.xp.array([1, 2, 3]) + freq_data = self.xp.asarray([1, 2, 3]) df = 0.1 psd = gwutils.psd_from_freq_series(freq_data, df) self.assertTrue(np.all(psd == (freq_data * 2 * df ** 0.5) ** 2)) - self.assertEqual(psd.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(psd), self.xp) def test_inner_product(self): - aa = self.xp.array([1, 2, 3]) - bb = self.xp.array([5, 6, 7]) - frequency = self.xp.array([0.2, 0.4, 0.6]) + aa = self.xp.asarray([1, 2, 3]) + bb = self.xp.asarray([5, 6, 7]) + frequency = self.xp.asarray([0.2, 0.4, 0.6]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() ip = 
gwutils.inner_product(aa, bb, frequency, PSD) self.assertEqual(ip, 0) - self.assertEqual(ip.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(ip), self.xp) def test_noise_weighted_inner_product(self): - aa = self.xp.array([1e-23, 2e-23, 3e-23]) - bb = self.xp.array([5e-23, 6e-23, 7e-23]) - frequency = self.xp.array([100, 101, 102]) + aa = self.xp.asarray([1e-23, 2e-23, 3e-23]) + bb = self.xp.asarray([5e-23, 6e-23, 7e-23]) + frequency = self.xp.asarray([100, 101, 102]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() psd = PSD.power_spectral_density_interpolated(frequency) duration = 4 @@ -65,12 +66,12 @@ def test_noise_weighted_inner_product(self): gwutils.optimal_snr_squared(aa, psd, duration), gwutils.noise_weighted_inner_product(aa, aa, psd, duration), ) - self.assertEqual(nwip.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(nwip), self.xp) def test_matched_filter_snr(self): - signal = self.xp.array([1e-23, 2e-23, 3e-23]) - frequency_domain_strain = self.xp.array([5e-23, 6e-23, 7e-23]) - frequency = self.xp.array([100, 101, 102]) + signal = self.xp.asarray([1e-23, 2e-23, 3e-23]) + frequency_domain_strain = self.xp.asarray([5e-23, 6e-23, 7e-23]) + frequency = self.xp.asarray([100, 101, 102]) PSD = bilby.gw.detector.PowerSpectralDensity.from_aligo() psd = PSD.power_spectral_density_interpolated(frequency) duration = 4 @@ -79,7 +80,7 @@ def test_matched_filter_snr(self): signal, frequency_domain_strain, psd, duration ) self.assertEqual(mfsnr, 25.510869054168282) - self.assertEqual(mfsnr.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(mfsnr), self.xp) def test_overlap(self): signal = self.xp.linspace(1e-23, 21e-23, 21) @@ -99,7 +100,7 @@ def test_overlap(self): norm_b=gwutils.optimal_snr_squared(frequency_domain_strain, psd, duration), ) self.assertAlmostEqual(overlap, 2.76914407e-05) - self.assertEqual(overlap.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(overlap), self.xp) 
@pytest.mark.skip(reason="GWOSC unstable: avoiding this test") def test_get_event_time(self): @@ -319,8 +320,8 @@ def test_conversion_single(self) -> None: ra, dec = bilby.gw.utils.zenith_azimuth_to_ra_dec( zenith, azimuth, time, self.ifos ) - self.assertEqual(ra.__array_namespace__(), self.xp) - self.assertEqual(dec.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(ra), self.xp) + self.assertEqual(aac.get_namespace(dec), self.xp) def test_conversion_gives_correct_prior(self) -> None: zeniths = self.xp.asarray(self.samples["zenith"]) @@ -332,8 +333,8 @@ def test_conversion_gives_correct_prior(self) -> None: ) self.assertGreaterEqual(ks_2samp(self.samples["ra"], ras).pvalue, 0.01) self.assertGreaterEqual(ks_2samp(self.samples["dec"], decs).pvalue, 0.01) - self.assertEqual(ras.__array_namespace__(), self.xp) - self.assertEqual(decs.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(ras), self.xp) + self.assertEqual(aac.get_namespace(decs), self.xp) @pytest.mark.array_backend diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py index efd59d352..b413498a9 100644 --- a/test/gw/waveform_generator_test.py +++ b/test/gw/waveform_generator_test.py @@ -1,6 +1,7 @@ import unittest from unittest import mock +import array_api_compat as aac import bilby import lalsimulation import numpy as np @@ -42,7 +43,7 @@ def dummy_func_dict_return_value( def dummy_func_array_return_value_2( array, amplitude, mu, sigma, ra, dec, geocent_time, psi, *, xp=None ): - return dict(plus=xp.array(array), cross=xp.array(array)) + return dict(plus=xp.asarray(array), cross=xp.asarray(array)) @pytest.mark.array_backend @@ -50,8 +51,8 @@ def dummy_func_array_return_value_2( class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - self.xp.array(1.0), - self.xp.array(4096.0), + self.xp.asarray(1.0), + 
self.xp.asarray(4096.0), frequency_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( @@ -125,11 +126,11 @@ def conversion_func(): def test_duration(self): self.assertEqual(self.waveform_generator.duration, 1) - self.assertEqual(self.waveform_generator.duration.__array_namespace__(), self.xp) + self.assertEqual(self.waveform_generator.aac.get_namespace(duration), self.xp) def test_sampling_frequency(self): self.assertEqual(self.waveform_generator.sampling_frequency, 4096) - self.assertEqual(self.waveform_generator.sampling_frequency.__array_namespace__(), self.xp) + self.assertEqual(self.waveform_generator.aac.get_namespace(sampling_frequency), self.xp) def test_source_model(self): self.assertEqual( @@ -315,8 +316,8 @@ def conversion_func(): class TestFrequencyDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=self.xp.array(1.0), - sampling_frequency=self.xp.array(4096.0), + duration=self.xp.asarray(1.0), + sampling_frequency=self.xp.asarray(4096.0), frequency_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( @@ -358,8 +359,8 @@ def test_frequency_domain_source_model_call(self): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) - self.assertEqual(actual["plus"].__array_namespace__(), self.xp) - self.assertEqual(actual["cross"].__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_time_domain_source_model_call_with_ndarray(self): self.waveform_generator.frequency_domain_source_model = None @@ -377,7 +378,7 @@ def side_effect(value, value2): parameters=self.simulation_parameters ) self.assertTrue(np.array_equal(expected, actual)) - self.assertEqual(actual.__array_namespace__(), self.xp) + 
self.assertEqual(aac.get_namespace(actual), self.xp) def test_time_domain_source_model_call_with_dict(self): self.waveform_generator.frequency_domain_source_model = None @@ -396,8 +397,8 @@ def side_effect(value, value2): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) - self.assertEqual(actual["plus"].__array_namespace__(), self.xp) - self.assertEqual(actual["cross"].__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None @@ -507,8 +508,8 @@ def test_frequency_domain_caching_changing_model(self): def test_time_domain_caching_changing_model(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=self.xp.array(1.0), - sampling_frequency=self.xp.array(4096.0), + duration=self.xp.asarray(1.0), + sampling_frequency=self.xp.asarray(4096.0), time_domain_source_model=dummy_func_dict_return_value, ) original_waveform = self.waveform_generator.frequency_domain_strain( @@ -523,8 +524,8 @@ def test_time_domain_caching_changing_model(self): self.assertFalse( np.array_equal(original_waveform["plus"], new_waveform["plus"]) ) - self.assertEqual(new_waveform["plus"].__array_namespace__(), self.xp) - self.assertEqual(new_waveform["cross"].__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(new_waveform["plus"]), self.xp) + self.assertEqual(aac.get_namespace(new_waveform["cross"]), self.xp) @pytest.mark.array_backend @@ -532,8 +533,8 @@ def test_time_domain_caching_changing_model(self): class TestTimeDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - self.xp.array(1.0), - self.xp.array(4096.0), + self.xp.asarray(1.0), + self.xp.asarray(4096.0), 
time_domain_source_model=dummy_func_dict_return_value, ) self.simulation_parameters = dict( @@ -575,8 +576,8 @@ def test_time_domain_source_model_call(self): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) - self.assertEqual(actual["plus"].__array_namespace__(), self.xp) - self.assertEqual(actual["cross"].__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_frequency_domain_source_model_call_with_ndarray(self): self.waveform_generator.time_domain_source_model = None @@ -596,7 +597,7 @@ def side_effect(value, value2): parameters=self.simulation_parameters ) self.assertTrue(np.array_equal(expected, actual)) - self.assertEqual(actual.__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual), self.xp) def test_frequency_domain_source_model_call_with_dict(self): self.waveform_generator.time_domain_source_model = None @@ -617,8 +618,8 @@ def side_effect(value, value2): ) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) - self.assertEqual(actual["plus"].__array_namespace__(), self.xp) - self.assertEqual(actual["cross"].__array_namespace__(), self.xp) + self.assertEqual(aac.get_namespace(actual["plus"]), self.xp) + self.assertEqual(aac.get_namespace(actual["cross"]), self.xp) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None From eea4f243bc6e0a3db717f344e79d61bdddaa5161 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 12:11:50 -0500 Subject: [PATCH 088/110] BUG: fix some broken formatting --- test/core/prior/slabspike_test.py | 4 ++-- test/core/utils_test.py | 6 +++--- test/gw/detector/geometry_test.py | 28 ++++++++++++++-------------- test/gw/utils_test.py | 4 ++-- test/gw/waveform_generator_test.py | 5 ++--- 5 
files changed, 23 insertions(+), 24 deletions(-) diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index c8a07aae8..1ec76ab71 100644 --- a/test/core/prior/slabspike_test.py +++ b/test/core/prior/slabspike_test.py @@ -129,8 +129,8 @@ def test_ln_prob_on_slab(self): def test_ln_prob_on_spike(self): for slab_spike in self.slab_spikes: - expected = slab_spike.ln_prob(self.spike_loc) - self.assertEqual(np.inf, expected) + actual = slab_spike.ln_prob(self.spike_loc) + self.assertEqual(np.inf, actual) self.assertEqual(aac.get_namespace(actual), self.xp) def test_inverse_cdf_below_spike_with_spike_at_minimum(self): diff --git a/test/core/utils_test.py b/test/core/utils_test.py index bd1869a6a..f766f2c74 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -361,11 +361,11 @@ def test_returns_none_for_floats_outside_range(self): def test_returns_float_for_float_and_array(self): input_array = self.xp.asarray(np.random.random(10)) - aac.get_namespace(self.assertEqual(self.interpolant(input_array, 0.5)), self.xp) - aac.get_namespace(self.assertEqual( + self.assertEqual(aac.get_namespace(self.interpolant(input_array, 0.5)), self.xp) + self.assertEqual(aac.get_namespace( self.interpolant(input_array, input_array)), self.xp ) - aac.get_namespace(self.assertEqual(self.interpolant(0.5, input_array)), self.xp) + self.assertEqual(aac.get_namespace(self.interpolant(0.5, input_array)), self.xp) def test_raises_for_mismatched_arrays(self): with self.assertRaises(ValueError): diff --git a/test/gw/detector/geometry_test.py b/test/gw/detector/geometry_test.py index b18eb218d..231b82b17 100644 --- a/test/gw/detector/geometry_test.py +++ b/test/gw/detector/geometry_test.py @@ -45,35 +45,35 @@ def tearDown(self): def test_length_setting(self): self.assertEqual(self.geometry.length, self.length) - self.assertEqual(self.geometry.aac.get_namespace(length), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.length), self.xp) def 
test_latitude_setting(self): self.assertEqual(self.geometry.latitude, self.latitude) - self.assertEqual(self.geometry.aac.get_namespace(latitude), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.latitude), self.xp) def test_longitude_setting(self): self.assertEqual(self.geometry.longitude, self.longitude) - self.assertEqual(self.geometry.aac.get_namespace(longitude), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.longitude), self.xp) def test_elevation_setting(self): self.assertEqual(self.geometry.elevation, self.elevation) - self.assertEqual(self.geometry.aac.get_namespace(elevation), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.elevation), self.xp) def test_xarm_azi_setting(self): self.assertEqual(self.geometry.xarm_azimuth, self.xarm_azimuth) - self.assertEqual(self.geometry.aac.get_namespace(xarm_azimuth), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.xarm_azimuth), self.xp) def test_yarm_azi_setting(self): self.assertEqual(self.geometry.yarm_azimuth, self.yarm_azimuth) - self.assertEqual(self.geometry.aac.get_namespace(yarm_azimuth), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.yarm_azimuth), self.xp) def test_xarm_tilt_setting(self): self.assertEqual(self.geometry.xarm_tilt, self.xarm_tilt) - self.assertEqual(self.geometry.aac.get_namespace(xarm_tilt), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.xarm_tilt), self.xp) def test_yarm_tilt_setting(self): self.assertEqual(self.geometry.yarm_tilt, self.yarm_tilt) - self.assertEqual(self.geometry.aac.get_namespace(yarm_tilt), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.yarm_tilt), self.xp) def test_vertex_without_update(self): _ = self.geometry.vertex @@ -155,37 +155,37 @@ def test_detector_tensor_with_x_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.xarm_azimuth += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - 
self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_y_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.yarm_azimuth += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_x_tilt_update(self): original = self.geometry.detector_tensor self.geometry.xarm_tilt += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_y_tilt_update(self): original = self.geometry.detector_tensor self.geometry.yarm_tilt += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_longitude_update(self): original = self.geometry.detector_tensor self.geometry.longitude += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_latitude_update(self): original = self.geometry.detector_tensor self.geometry.latitude += 1 self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) - self.assertEqual(self.geometry.aac.get_namespace(detector_tensor), self.xp) + self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_unit_vector_along_arm_default(self): with 
self.assertRaises(ValueError): diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index 80b8fe60d..679956b62 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -31,14 +31,14 @@ def tearDown(self): def test_asd_from_freq_series(self): freq_data = self.xp.asarray([1, 2, 3]) - df = 0.1 + df = self.xp.asarray(0.1) asd = gwutils.asd_from_freq_series(freq_data, df) self.assertTrue(np.all(asd == freq_data * 2 * df ** 0.5)) self.assertEqual(aac.get_namespace(asd), self.xp) def test_psd_from_freq_series(self): freq_data = self.xp.asarray([1, 2, 3]) - df = 0.1 + df = self.xp.asarray(0.1) psd = gwutils.psd_from_freq_series(freq_data, df) self.assertTrue(np.all(psd == (freq_data * 2 * df ** 0.5) ** 2)) self.assertEqual(aac.get_namespace(psd), self.xp) diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py index b413498a9..5c883e8ab 100644 --- a/test/gw/waveform_generator_test.py +++ b/test/gw/waveform_generator_test.py @@ -126,12 +126,11 @@ def conversion_func(): def test_duration(self): self.assertEqual(self.waveform_generator.duration, 1) - self.assertEqual(self.waveform_generator.aac.get_namespace(duration), self.xp) + self.assertEqual(aac.get_namespace(self.waveform_generator.duration), self.xp) def test_sampling_frequency(self): self.assertEqual(self.waveform_generator.sampling_frequency, 4096) - self.assertEqual(self.waveform_generator.aac.get_namespace(sampling_frequency), self.xp) - + self.assertEqual(aac.get_namespace(self.waveform_generator.sampling_frequency), self.xp) def test_source_model(self): self.assertEqual( self.waveform_generator.frequency_domain_source_model, From 0a17a5da85a45944b06a3c2f28d2553c726c3c25 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 13:36:03 -0500 Subject: [PATCH 089/110] FMT: fix formatting --- bilby/core/prior/analytical.py | 2 +- test/conftest.py | 2 +- test/core/prior/analytical_test.py | 1 - test/gw/waveform_generator_test.py | 1 + 4 files changed, 3 
insertions(+), 3 deletions(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 9dd120964..11c1aa038 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -1530,7 +1530,7 @@ def prob(self, val, *, xp=None): index = xp.searchsorted(xp.asarray(self._values_array), val) index = xp.clip(index, 0, self.nvalues - 1) p = xp.where( - xp.asarray(self._values_array[index])== val, + xp.asarray(self._values_array[index]) == val, xp.asarray(self._weights_array[index]), xp.asarray(0.0), ) diff --git a/test/conftest.py b/test/conftest.py index 2bcd416cd..fda4a4387 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -50,7 +50,7 @@ def _xp(request): xp = jax.numpy case _: try: - xp = importlib.import_module(backend) + xp = importlib.import_module(backend) except ImportError: raise ValueError(f"Unknown backend for testing: {backend}") return aac.get_namespace(xp.ones(1)) diff --git a/test/core/prior/analytical_test.py b/test/core/prior/analytical_test.py index dddc0dfba..ec4ec975b 100644 --- a/test/core/prior/analytical_test.py +++ b/test/core/prior/analytical_test.py @@ -201,7 +201,6 @@ def test_array_sample(self): self.assertAlmostEqual(case / N, categorical_prior.prob(i), places=int(np.log10(np.sqrt(N)))) self.assertAlmostEqual(case / N, weights[i] / np.sum(weights), places=int(np.log10(np.sqrt(N)))) self.assertEqual(cases, N) - def test_single_probability(self): N = 3 diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py index 5c883e8ab..8a98369b5 100644 --- a/test/gw/waveform_generator_test.py +++ b/test/gw/waveform_generator_test.py @@ -131,6 +131,7 @@ def test_duration(self): def test_sampling_frequency(self): self.assertEqual(self.waveform_generator.sampling_frequency, 4096) self.assertEqual(aac.get_namespace(self.waveform_generator.sampling_frequency), self.xp) + def test_source_model(self): self.assertEqual( self.waveform_generator.frequency_domain_source_model, From 
713ca68ec925d1f2e74958a9d842267e08ab250b Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 2 Feb 2026 13:43:41 -0500 Subject: [PATCH 090/110] BUG: fix bugs in testing --- bilby/compat/utils.py | 15 ++++++++++++--- test/core/likelihood_test.py | 12 ++++++------ test/gw/utils_test.py | 4 ++-- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py index 34b819ab9..5c8e7dd95 100644 --- a/bilby/compat/utils.py +++ b/bilby/compat/utils.py @@ -1,3 +1,4 @@ +import inspect from collections.abc import Iterable import numpy as np @@ -61,8 +62,7 @@ def xp_wrap(func, no_xp=False): function The decorated function. """ - - def wrapped(self, *args, xp=None, **kwargs): + def parse_args_kwargs_for_xp(*args, xp=None, **kwargs): if not no_xp and xp is None: try: # if the user specified the target arrays in kwargs @@ -78,7 +78,16 @@ def wrapped(self, *args, xp=None, **kwargs): kwargs["xp"] = np elif not no_xp: kwargs["xp"] = xp - return func(self, *args, **kwargs) + return args, kwargs + + if inspect.isfunction(func): + def wrapped(*args, xp=None, **kwargs): + args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs) + return func(*args, **kwargs) + else: + def wrapped(self, *args, xp=None, **kwargs): + args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs) + return func(self, *args, **kwargs) return wrapped diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index ab12208c8..368b7f1e5 100644 --- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -58,8 +58,8 @@ def test_meta_data(self): @pytest.mark.usefixtures("xp_class") class TestAnalytical1DLikelihood(unittest.TestCase): def setUp(self): - self.x = self.xp.arange(start=0, stop=100, step=1) - self.y = self.xp.arange(start=0, stop=100, step=1) + self.x = self.xp.arange(0, 100, step=1) + self.y = self.xp.arange(0, 100, step=1) def test_func(x, parameter1, parameter2): return parameter1 * x + parameter2 @@ -85,7 +85,7 @@ def 
test_init_x(self): self.assertTrue(np.array_equal(self.x, self.analytical_1d_likelihood.x)) def test_set_x_to_array(self): - new_x = self.xp.arange(start=0, stop=50, step=2) + new_x = self.xp.arange(0, 50, step=2) self.analytical_1d_likelihood.x = new_x self.assertTrue(np.array_equal(new_x, self.analytical_1d_likelihood.x)) @@ -105,7 +105,7 @@ def test_init_y(self): self.assertTrue(np.array_equal(self.y, self.analytical_1d_likelihood.y)) def test_set_y_to_array(self): - new_y = self.xp.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(0, 50, step=2) self.analytical_1d_likelihood.y = new_y self.assertTrue(np.array_equal(new_y, self.analytical_1d_likelihood.y)) @@ -355,7 +355,7 @@ def test_init_y(self): self.assertTrue(self.xp.array_equal(self.y, self.poisson_likelihood.y)) def test_set_y_to_array(self): - new_y = self.xp.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(0, 50, step=2) self.poisson_likelihood.y = new_y self.assertTrue(self.xp.array_equal(new_y, self.poisson_likelihood.y)) @@ -459,7 +459,7 @@ def test_init_y(self): self.assertTrue(np.array_equal(self.y, self.exponential_likelihood.y)) def test_set_y_to_array(self): - new_y = self.xp.arange(start=0, stop=50, step=2) + new_y = self.xp.arange(0, 50, step=2) self.exponential_likelihood.y = new_y self.assertTrue(np.array_equal(new_y, self.exponential_likelihood.y)) diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index 679956b62..80b8fe60d 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -31,14 +31,14 @@ def tearDown(self): def test_asd_from_freq_series(self): freq_data = self.xp.asarray([1, 2, 3]) - df = self.xp.asarray(0.1) + df = 0.1 asd = gwutils.asd_from_freq_series(freq_data, df) self.assertTrue(np.all(asd == freq_data * 2 * df ** 0.5)) self.assertEqual(aac.get_namespace(asd), self.xp) def test_psd_from_freq_series(self): freq_data = self.xp.asarray([1, 2, 3]) - df = self.xp.asarray(0.1) + df = 0.1 psd = gwutils.psd_from_freq_series(freq_data, df) 
         self.assertTrue(np.all(psd == (freq_data * 2 * df ** 0.5) ** 2))
         self.assertEqual(aac.get_namespace(psd), self.xp)
 

From e0c41db3a9514a6e88478b63ceeb97058f9dbfbc Mon Sep 17 00:00:00 2001
From: Colm Talbot
Date: Mon, 2 Feb 2026 14:16:46 -0500
Subject: [PATCH 091/110] Fix some more conversions

---
 bilby/compat/utils.py        | 27 +++++++++++++++++----------
 bilby/core/utils/calculus.py | 15 +++++++--------
 test/gw/likelihood_test.py   | 10 ++++++----
 3 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py
index 5c8e7dd95..7c2169d09 100644
--- a/bilby/compat/utils.py
+++ b/bilby/compat/utils.py
@@ -66,28 +66,35 @@ def parse_args_kwargs_for_xp(*args, xp=None, **kwargs):
         if not no_xp and xp is None:
             try:
                 # if the user specified the target arrays in kwargs
-                # we need to be able to support this
+                # we need to be able to support this, if there is
+                # only one kwargs, pass it through alone, this is
+                # sometimes a dictionary of arrays so this is needed
+                # to remove a level of nesting
                 if len(args) > 0:
-                    xp = array_module(*args)
-                elif len(kwargs) > 0:
-                    xp = array_module(*kwargs.values())
+                    xp = array_module(args)
+                elif len(kwargs) == 1:
+                    xp = array_module(next(iter(kwargs.values())))
+                elif len(kwargs) > 1:
+                    xp = array_module(kwargs)
                 else:
                     xp = np
                 kwargs["xp"] = xp
-            except TypeError:
+            except TypeError:
+                # fall back to numpy when no array namespace can be found
                 kwargs["xp"] = np
         elif not no_xp:
             kwargs["xp"] = xp
         return args, kwargs
 
-    if inspect.isfunction(func):
-        def wrapped(*args, xp=None, **kwargs):
-            args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs)
-            return func(*args, **kwargs)
-    else:
+    sig = inspect.signature(func)
+    if any(name in sig.parameters for name in ("self", "cls")):
         def wrapped(self, *args, xp=None, **kwargs):
             args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs)
             return func(self, *args, **kwargs)
+    else:
+        def wrapped(*args, xp=None, **kwargs):
+            args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, 
**kwargs) + return func(*args, **kwargs) return wrapped diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index 6852299ec..01a92b9c8 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -175,13 +175,14 @@ def logtrapzexp(lnf, dx, *, xp=np): lnfdx1 = lnf[:-1] lnfdx2 = lnf[1:] - if ( - isinstance(dx, (int, float)) or - (aac.is_array_api_obj(dx) and dx.size == 1) - ): - C = np.log(dx / 2.0) - elif isinstance(dx, (list, xp.ndarray)): + try: dx = xp.asarray(dx) + except TypeError: + raise TypeError(f"Step size dx={dx} could not be converted to an array") + + if dx.size == 1: + C = np.log(dx / 2.0) + else: if dx.size != len(lnf) - 1: raise ValueError( "Step size array must have length one less than the function length" @@ -191,8 +192,6 @@ def logtrapzexp(lnf, dx, *, xp=np): lnfdx1 = lnfdx1.copy() + lndx lnfdx2 = lnfdx2.copy() + lndx C = -xp.log(2.0) - else: - raise TypeError("Step size must be a single value or array-like") return C + logsumexp(xp.asarray([logsumexp(lnfdx1), logsumexp(lnfdx2)])) diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 35f4de99b..c2dcf1529 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -319,15 +319,17 @@ def test_time_reference_agrees_with_default(self): ) parameters = self.parameters.copy() parameters["H1_time"] = parameters["geocent_time"] + time_delay - self.assertEqual( + self.assertAlmostEqual( new_likelihood.log_likelihood_ratio(parameters), - self.likelihood.log_likelihood_ratio(parameters) + self.likelihood.log_likelihood_ratio(parameters), + 8, ) new_likelihood.parameters.update(parameters) self.likelihood.parameters.update(parameters) - self.assertEqual( + self.assertAlmostEqual( new_likelihood.log_likelihood_ratio(), - self.likelihood.log_likelihood_ratio() + self.likelihood.log_likelihood_ratio(), + 8, ) From 59512294bcaf73f7315ad415543ebf720e19c341 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 08:23:11 -0500 
Subject: [PATCH 092/110] Add pytorch core testing --- .github/workflows/unit-tests.yml | 3 ++ bilby/compat/patches.py | 4 +- bilby/compat/utils.py | 2 + bilby/core/grid.py | 35 ++++++++++---- bilby/core/likelihood.py | 7 +-- bilby/core/prior/analytical.py | 73 ++++++++++++++++-------------- bilby/core/prior/dict.py | 71 ++++++++++++++++------------- bilby/core/prior/joint.py | 20 +++++--- bilby/core/prior/slabspike.py | 7 +-- bilby/core/utils/calculus.py | 20 +++----- bilby/core/utils/io.py | 4 +- bilby/core/utils/series.py | 2 +- bilby/gw/prior.py | 6 +-- requirements.txt | 3 +- test/conftest.py | 12 +++++ test/core/grid_test.py | 16 +++---- test/core/likelihood_test.py | 30 ++++++------ test/core/prior/analytical_test.py | 22 ++++----- test/core/prior/base_test.py | 2 + test/core/prior/prior_test.py | 72 ++++++++++++++++------------- test/core/prior/slabspike_test.py | 15 +++--- test/core/series_test.py | 11 +++-- test/core/utils_test.py | 5 ++ 23 files changed, 256 insertions(+), 186 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 395248530..6bce425e2 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -64,6 +64,9 @@ jobs: run: | python -m pip install .[jax] pytest --array-backend jax --durations 10 + - name: Run torch-backend unit tests + run: | + pytest --array-backend torch --durations 10 test/core - name: Run sampler tests run: | pytest test/integration/sampler_run_test.py --durations 10 -v diff --git a/bilby/compat/patches.py b/bilby/compat/patches.py index 19ad0565a..53345abd4 100644 --- a/bilby/compat/patches.py +++ b/bilby/compat/patches.py @@ -41,8 +41,8 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, *, xp=Non if xp is None: xp = aac.get_namespace(a) - if "jax" in xp.__name__: - # the scipy version of logsumexp cannot be vmapped + # the scipy version of logsumexp cannot be vmapped + if aac.is_jax_namespace(xp): from jax.scipy.special import 
logsumexp as lse else: from scipy.special import logsumexp as lse diff --git a/bilby/compat/utils.py b/bilby/compat/utils.py index 7c2169d09..a05cc0920 100644 --- a/bilby/compat/utils.py +++ b/bilby/compat/utils.py @@ -10,6 +10,8 @@ def array_module(arr): + if isinstance(arr, tuple) and len(arr) == 1: + arr = arr[0] try: return array_namespace(arr) except TypeError: diff --git a/bilby/core/grid.py b/bilby/core/grid.py index 2377dc0d5..4d1a7501c 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -1,5 +1,6 @@ import json import os +from copy import copy import array_api_compat as aac import numpy as np @@ -169,7 +170,7 @@ def marginalize(self, log_array, parameters=None, not_parameters=None): else: raise TypeError("Parameters names must be a list or string") - out_array = log_array.copy() + out_array = copy(log_array) names = list(self.parameter_names) for name in params: @@ -212,9 +213,17 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): if len(places) > 1: dx = xp.diff(places) - out = xp.apply_along_axis( - logtrapzexp, axis, log_array, dx - ) + if log_array.ndim == 1: + out = logtrapzexp(log_array, dx=dx, xp=xp) + elif aac.is_torch_namespace(xp): + # https://discuss.pytorch.org/t/apply-a-function-along-an-axis/130440 + out = xp.stack([ + logtrapzexp(x_i, dx=dx, xp=xp) for x_i in xp.unbind(log_array, dim=axis) + ], dim=min(axis, log_array.ndim - 2)) + else: + out = xp.apply_along_axis( + logtrapzexp, axis, log_array, dx + ) else: # no marginalisation required, just remove the singleton dimension z = log_array.shape @@ -327,8 +336,11 @@ def marginalize_posterior(self, parameters=None, not_parameters=None): def _evaluate(self): xp = aac.get_namespace(self.mesh_grid[0]) - if xp.__name__ == "jax.numpy": - from jax import vmap + if aac.is_torch_namespace(xp) or aac.is_jax_namespace(xp): + if aac.is_torch_namespace(xp): + from torch import vmap + else: + from jax import vmap self._ln_likelihood = vmap(self.likelihood.log_likelihood)( {key: 
self.mesh_grid[i].flatten() for i, key in enumerate(self.parameter_names)} ).reshape(self.mesh_grid[0].shape) @@ -364,13 +376,13 @@ def _get_sample_points(self, grid_size, *, xp=np): self.sample_points[key] = self.priors[key].rescale( xp.linspace(0, 1, grid_size[ii])) else: - self.sample_points[key] = grid_size[ii] + self.sample_points[key] = xp.asarray(grid_size[ii]) elif isinstance(grid_size, dict): if isinstance(grid_size[key], int): self.sample_points[key] = self.priors[key].rescale( xp.linspace(0, 1, grid_size[key])) else: - self.sample_points[key] = grid_size[key] + self.sample_points[key] = xp.asarray(grid_size[key]) else: raise TypeError("Unrecognized 'grid_size' type") @@ -451,7 +463,7 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None, "following message:\n {} \n\n".format(e)) @classmethod - def read(cls, filename=None, outdir=None, label=None, gzip=False): + def read(cls, filename=None, outdir=None, label=None, gzip=False, xp=None): """ Read in a saved .json grid file Parameters @@ -464,6 +476,9 @@ def read(cls, filename=None, outdir=None, label=None, gzip=False): If given, whether the file is gzipped or not (only required if the file is gzipped, but does not have the standard '.gz' file extension) + xp: array module | None + The array module to use for calculations (e.g., :code:`numpy`, + :code:`jax.numpy`). If :code:`None`, defaults to :code:`numpy`. 
Returns ======= @@ -487,7 +502,7 @@ def read(cls, filename=None, outdir=None, label=None, gzip=False): try: grid = cls(likelihood=None, priors=dictionary['priors'], grid_size=dictionary['sample_points'], - label=dictionary['label'], outdir=dictionary['outdir']) + label=dictionary['label'], outdir=dictionary['outdir'], xp=xp) # set the likelihood grid._ln_likelihood = dictionary['ln_likelihood'] diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index da333f534..510b6bfce 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -288,7 +288,7 @@ def log_likelihood(self, parameters=None): xp = array_module(self.x) sigma = parameters.get("sigma", self.sigma) log_l = xp.sum(- (self.residual(parameters) / sigma)**2 / 2 - - xp.log(2 * np.pi * sigma**2) / 2) + xp.log(xp.asarray(2 * np.pi * sigma**2)) / 2) return log_l def __repr__(self): @@ -374,7 +374,8 @@ def y(self, y): y = np.atleast_1d(y) xp = aac.get_namespace(y) # check array is a non-negative integer array - if y.dtype.kind not in 'ui' or xp.any(y < 0): + # torch doesn't support checking dtype kind + if (not aac.is_torch_namespace(xp) and y.dtype.kind not in 'ui') or xp.any(y < 0): raise ValueError("Data must be non-negative integers") self.__y = y @@ -468,7 +469,7 @@ def log_likelihood(self, parameters=None): xp = array_module(self.x) log_l =\ xp.sum(- (nu + 1) * xp.log1p(self.lam * self.residual(parameters=parameters)**2 / nu) / 2 + - xp.log(self.lam / (nu * np.pi)) / 2 + + xp.log(xp.asarray(self.lam / (nu * np.pi))) / 2 + gammaln((nu + 1) / 2) - gammaln(nu / 2)) return log_l diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 11c1aa038..d0a9bc22e 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -122,7 +122,7 @@ def rescale(self, val, *, xp=None): Union[float, array_like]: Rescaled probability """ if self.alpha == -1: - return self.minimum * xp.exp(val * xp.log(self.maximum / self.minimum)) + return self.minimum * 
xp.exp(val * xp.log(xp.asarray(self.maximum / self.minimum))) else: return (self.minimum ** (1 + self.alpha) + val * (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) ** (1. / (1 + self.alpha)) @@ -140,7 +140,7 @@ def prob(self, val, *, xp=None): float: Prior probability of val """ if self.alpha == -1: - return xp.nan_to_num(1 / val / xp.log(self.maximum / self.minimum)) * self.is_in_prior_range(val) + return xp.nan_to_num(1 / val / xp.log(xp.asarray(self.maximum / self.minimum))) * self.is_in_prior_range(val) else: return xp.nan_to_num(val ** self.alpha * (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - @@ -160,10 +160,11 @@ def ln_prob(self, val, *, xp=None): """ if self.alpha == -1: - normalising = 1. / xp.log(self.maximum / self.minimum) + normalising = 1. / xp.log(xp.asarray(self.maximum / self.minimum)) else: - normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - - self.minimum ** (1 + self.alpha)) + normalising = (1 + self.alpha) / xp.asarray( + self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha) + ) with np.errstate(divide='ignore', invalid='ignore'): ln_in_range = xp.log(1. 
* self.is_in_prior_range(val)) @@ -175,7 +176,7 @@ def ln_prob(self, val, *, xp=None): def cdf(self, val, *, xp=None): if self.alpha == -1: with np.errstate(invalid="ignore"): - _cdf = xp.log(val / self.minimum) / xp.log(self.maximum / self.minimum) + _cdf = xp.log(val / self.minimum) / xp.log(xp.asarray(self.maximum / self.minimum)) else: _cdf = ( (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) @@ -334,7 +335,11 @@ def rescale(self, val, *, xp=None): ======= Union[float, array_like]: Rescaled probability """ - return xp.sign(2 * val - 1) * self.minimum * xp.exp(abs(2 * val - 1) * xp.log(self.maximum / self.minimum)) + return ( + xp.sign(2 * val - 1) + * self.minimum + * xp.exp(xp.abs(2 * val - 1) * xp.log(xp.asarray(self.maximum / self.minimum))) + ) @xp_wrap def prob(self, val, *, xp=None): @@ -349,7 +354,7 @@ def prob(self, val, *, xp=None): float: Prior probability of val """ val = xp.abs(val) - return (xp.nan_to_num(0.5 / val / xp.log(self.maximum / self.minimum)) * + return (xp.nan_to_num(0.5 / val / xp.log(xp.asarray(self.maximum / self.minimum))) * self.is_in_prior_range(val)) @xp_wrap @@ -365,11 +370,11 @@ def ln_prob(self, val, *, xp=None): float: """ - return xp.nan_to_num(- xp.log(2 * xp.abs(val)) - xp.log(xp.log(self.maximum / self.minimum))) + return xp.nan_to_num(- xp.log(2 * xp.abs(val)) - xp.log(xp.log(xp.asarray(self.maximum / self.minimum)))) @xp_wrap def cdf(self, val, *, xp=None): - asymmetric = xp.log(abs(val) / self.minimum) / xp.log(self.maximum / self.minimum) + asymmetric = xp.log(xp.abs(val) / self.minimum) / xp.log(xp.asarray(self.maximum / self.minimum)) return xp.clip(0.5 * (1 + xp.sign(val) * asymmetric), 0, 1) @@ -404,8 +409,8 @@ def rescale(self, val, *, xp=None): This maps to the inverse CDF. This has been analytically solved for this case. 
""" - norm = 1 / (xp.sin(self.maximum) - xp.sin(self.minimum)) - return xp.arcsin(val / norm + xp.sin(self.minimum)) + norm = 1 / (xp.sin(xp.asarray(self.maximum)) - xp.sin(xp.asarray(self.minimum))) + return xp.arcsin(val / norm + xp.sin(xp.asarray(self.minimum))) @xp_wrap def prob(self, val, *, xp=None): @@ -424,8 +429,8 @@ def prob(self, val, *, xp=None): @xp_wrap def cdf(self, val, *, xp=None): _cdf = ( - (xp.sin(val) - xp.sin(self.minimum)) / - (xp.sin(self.maximum) - xp.sin(self.minimum)) + (xp.sin(val) - xp.sin(xp.asarray(self.minimum))) / + (xp.sin(xp.asarray(self.maximum)) - xp.sin(xp.asarray(self.minimum))) ) _cdf *= val >= self.minimum _cdf *= val <= self.maximum @@ -464,8 +469,8 @@ def rescale(self, val, *, xp=None): This maps to the inverse CDF. This has been analytically solved for this case. """ - norm = 1 / (xp.cos(self.minimum) - xp.cos(self.maximum)) - return xp.arccos(xp.cos(self.minimum) - val / norm) + norm = 1 / (xp.cos(xp.asarray(self.minimum)) - xp.cos(xp.asarray(self.maximum))) + return xp.arccos(xp.cos(xp.asarray(self.minimum)) - val / norm) @xp_wrap def prob(self, val, *, xp=None): @@ -484,8 +489,8 @@ def prob(self, val, *, xp=None): @xp_wrap def cdf(self, val, *, xp=None): _cdf = ( - (xp.cos(val) - xp.cos(self.minimum)) - / (xp.cos(self.maximum) - xp.cos(self.minimum)) + (xp.cos(val) - xp.cos(xp.asarray(self.minimum))) + / (xp.cos(xp.asarray(self.maximum)) - xp.cos(xp.asarray(self.minimum))) ) _cdf *= val >= self.minimum _cdf *= val <= self.maximum @@ -560,7 +565,7 @@ def ln_prob(self, val, *, xp=None): ======= Union[float, array_like]: Prior probability of val """ - return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + xp.log(2 * np.pi * self.sigma ** 2)) + return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + xp.log(xp.asarray(2 * np.pi * self.sigma ** 2))) def cdf(self, val, *, xp=None): return (1 - erf((self.mu - val) / 2 ** 0.5 / self.sigma)) / 2 @@ -754,15 +759,15 @@ def ln_prob(self, val, *, xp=None): """ with 
np.errstate(divide="ignore", invalid="ignore"): return xp.nan_to_num(( - -(xp.log(xp.maximum(val, self.minimum)) - self.mu) ** 2 / self.sigma ** 2 / 2 - - xp.log(xp.sqrt(2 * np.pi) * val * self.sigma) + -(xp.log(xp.maximum(val, xp.asarray(self.minimum))) - self.mu) ** 2 / self.sigma ** 2 / 2 + - xp.log((2 * np.pi)**0.5 * val * self.sigma) ) + xp.log(val > self.minimum), nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) @xp_wrap def cdf(self, val, *, xp=None): with np.errstate(divide="ignore"): return 0.5 + erf( - (xp.log(xp.maximum(val, self.minimum)) - self.mu) / self.sigma / np.sqrt(2) + (xp.log(xp.maximum(val, xp.asarray(self.minimum))) - self.mu) / self.sigma / np.sqrt(2) ) / 2 @@ -828,12 +833,12 @@ def ln_prob(self, val, *, xp=None): Union[float, array_like]: Prior probability of val """ with np.errstate(divide="ignore"): - return -val / self.mu - xp.log(self.mu) + xp.log(val >= self.minimum) + return -val / self.mu - xp.log(xp.asarray(self.mu)) + xp.log(val >= self.minimum) @xp_wrap def cdf(self, val, *, xp=None): with np.errstate(divide="ignore", invalid="ignore", over="ignore"): - return xp.maximum(1. - xp.exp(-val / self.mu), 0) + return xp.maximum(1. 
- xp.exp(-val / self.mu), xp.asarray(0.0)) class StudentT(Prior): @@ -916,7 +921,7 @@ def ln_prob(self, val, *, xp=None): """ return ( gammaln(0.5 * (self.df + 1)) - gammaln(0.5 * self.df) - - xp.log((np.pi * self.df)**0.5 * self.scale) - (self.df + 1) / 2 + - xp.log(xp.asarray((np.pi * self.df)**0.5 * self.scale)) - (self.df + 1) / 2 * xp.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) ) @@ -1053,7 +1058,7 @@ def rescale(self, val, *, xp=None): """ with np.errstate(divide="ignore"): val = xp.asarray(val) - return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), 0)) + return self.mu + self.scale * xp.log(xp.maximum(val / (1 - val), xp.asarray(0))) @xp_wrap def prob(self, val, *, xp=None): @@ -1083,7 +1088,7 @@ def ln_prob(self, val, *, xp=None): """ with np.errstate(over="ignore"): return -(val - self.mu) / self.scale -\ - 2. * xp.log1p(xp.exp(-(val - self.mu) / self.scale)) - xp.log(self.scale) + 2. * xp.log1p(xp.exp(-(val - self.mu) / self.scale)) - xp.log(xp.asarray(self.scale)) @xp_wrap def cdf(self, val, *, xp=None): @@ -1155,7 +1160,7 @@ def ln_prob(self, val, *, xp=None): ======= Union[float, array_like]: Log prior probability of val """ - return - xp.log(self.beta * np.pi) - xp.log(1. + ((val - self.alpha) / self.beta) ** 2) + return - xp.log(xp.asarray(self.beta * np.pi)) - xp.log(1. 
+ ((val - self.alpha) / self.beta) ** 2) @xp_wrap def cdf(self, val, *, xp=None): @@ -1240,7 +1245,7 @@ def ln_prob(self, val, *, xp=None): @xp_wrap def cdf(self, val, *, xp=None): - return gammainc(xp.asarray(self.k), xp.maximum(val, self.minimum) / self.theta) + return gammainc(xp.asarray(self.k), xp.maximum(val, xp.asarray(self.minimum)) / self.theta) class ChiSquared(Gamma): @@ -1327,7 +1332,7 @@ def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, raise ValueError("For the Fermi-Dirac prior the values of sigma and r " "must be positive.") - xp = array_module(np) + xp = array_module((mu, sigma, r)) self.expr = xp.exp(self.r) @xp_wrap @@ -1349,7 +1354,7 @@ def rescale(self, val, *, xp=None): `_, 2017. """ inv = -1 / self.expr + (1 + self.expr)**-val + (1 + self.expr)**-val / self.expr - return -self.sigma * xp.log(xp.maximum(inv, 0)) + return -self.sigma * xp.log(xp.maximum(inv, xp.asarray(0))) @xp_wrap def prob(self, val, *, xp=None): @@ -1365,7 +1370,7 @@ def prob(self, val, *, xp=None): """ return ( (xp.exp((val - self.mu) / self.sigma) + 1)**-1 - / (self.sigma * xp.log1p(self.expr)) + / (self.sigma * xp.log1p(xp.asarray(self.expr))) * (val >= self.minimum) ) @@ -1406,8 +1411,8 @@ def cdf(self, val, *, xp=None): `_, 2017. 
""" result = ( - (xp.logaddexp(0, -self.r) - xp.logaddexp(-val / self.sigma, -self.r)) - / xp.logaddexp(0, self.r) + (xp.logaddexp(xp.asarray(0.0), -xp.asarray(self.r)) - xp.logaddexp(-val / self.sigma, -xp.asarray(self.r))) + / xp.logaddexp(xp.asarray(0.0), xp.asarray(self.r)) ) return xp.clip(result, 0, 1) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index a7ca7589f..d4d627299 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -5,6 +5,7 @@ from io import open as ioopen from warnings import warn +import array_api_extra as xpx import numpy as np from .analytical import DeltaFunction @@ -58,12 +59,13 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None): def __hash__(self): return hash(str(self)) - def evaluate_constraints(self, sample): + @xp_wrap + def evaluate_constraints(self, sample, *, xp=None): out_sample = self.conversion_function(sample) try: - prob = np.ones_like(next(iter(out_sample.values()))) + prob = xp.ones_like(next(iter(out_sample.values())), dtype=bool) except TypeError: - prob = np.ones_like(out_sample) + prob = xp.ones_like(out_sample, dtype=bool) for key in self: if isinstance(self[key], Constraint) and key in out_sample: prob *= self[key].prob(out_sample[key]) @@ -385,7 +387,7 @@ def sample_subset_constrained_as_array(self, keys=iter([]), size=None, *, xp=np) samples_dict = self.sample_subset_constrained(keys=keys, size=size, xp=xp) samples_dict = {key: xp.atleast_1d(val) for key, val in samples_dict.items()} samples_list = [samples_dict[key] for key in keys] - return xp.asarray(samples_list) + return xp.stack(samples_list) def sample_subset(self, keys=iter([]), size=None, *, xp=np): """Draw samples from the prior set for parameters which are not a DeltaFunction @@ -474,18 +476,20 @@ def check_efficiency(n_tested, n_valid): for key in keys.copy(): if isinstance(self[key], Constraint): del keys[keys.index(key)] - all_samples = {key: np.array([]) for key in keys} + all_samples = 
{key: xp.asarray([]) for key in keys} _first_key = list(all_samples.keys())[0] while len(all_samples[_first_key]) < needed: samples = self.sample_subset(keys=keys, size=needed, xp=xp) - keep = np.array(self.evaluate_constraints(samples), dtype=bool) + keep = self.evaluate_constraints(samples, xp=xp) for key in keys: all_samples[key] = xp.hstack( [all_samples[key], samples[key][keep].flatten()] ) n_tested_samples += needed - n_valid_samples += np.sum(keep) + n_valid_samples += int(xp.sum(keep)) check_efficiency(n_tested_samples, n_valid_samples) + if not isinstance(size, tuple): + size = (size,) all_samples = { key: xp.reshape(all_samples[key][:needed], size) for key in keys } @@ -527,7 +531,8 @@ def _estimate_normalization(self, keys, min_accept, sampling_chunk, *, xp=np): factor = len(keep) / np.count_nonzero(keep) return factor - def prob(self, sample, **kwargs): + @xp_wrap + def prob(self, sample, *, xp=None, **kwargs): """ Parameters @@ -542,31 +547,31 @@ def prob(self, sample, **kwargs): float: Joint probability of all individual sample probabilities """ - xp = array_module(sample.values()) - prob = xp.prod(xp.asarray([self[key].prob(sample[key]) for key in sample]), **kwargs) + prob = xp.prod(xp.stack([self[key].prob(sample[key], xp=xp) for key in sample]), **kwargs) - return prob - # return self.check_prob(sample, prob) + return self.check_prob(sample, prob, xp=xp) - def check_prob(self, sample, prob): + @xp_wrap + def check_prob(self, sample, prob, *, xp=None): ratio = self.normalize_constraint_factor(tuple(sample.keys())) - if np.all(prob == 0.0): + if xp.all(prob == 0.0): return prob * ratio else: if isinstance(prob, float): - if self.evaluate_constraints(sample): + if self.evaluate_constraints(sample, xp=xp): return prob * ratio else: return 0.0 else: - constrained_prob = np.zeros_like(prob) - in_bounds = np.isfinite(prob) + constrained_prob = xp.zeros_like(prob) + in_bounds = xp.isfinite(prob) subsample = {key: sample[key][in_bounds] for key in sample} 
- keep = np.array(self.evaluate_constraints(subsample), dtype=bool) - constrained_prob[in_bounds] = prob[in_bounds] * keep * ratio + keep = self.evaluate_constraints(subsample, xp=xp) + constrained_prob = xpx.at(constrained_prob, in_bounds).set(prob[in_bounds] * keep * ratio) return constrained_prob - def ln_prob(self, sample, axis=None, normalized=True): + @xp_wrap + def ln_prob(self, sample, axis=None, normalized=True, *, xp=None): """ Parameters @@ -585,29 +590,32 @@ def ln_prob(self, sample, axis=None, normalized=True): Joint log probability of all the individual sample probabilities """ - ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis) + ln_prob = xp.sum(xp.stack([self[key].ln_prob(sample[key], xp=xp) for key in sample]), axis=axis) return self.check_ln_prob(sample, ln_prob, - normalized=normalized) + normalized=normalized, xp=xp) - def check_ln_prob(self, sample, ln_prob, normalized=True): + @xp_wrap + def check_ln_prob(self, sample, ln_prob, normalized=True, *, xp=None): if normalized: ratio = self.normalize_constraint_factor(tuple(sample.keys())) else: ratio = 1 - if np.all(np.isinf(ln_prob)): + if xp.all(xp.isfinite(ln_prob)): return ln_prob else: if isinstance(ln_prob, float): - if np.all(self.evaluate_constraints(sample)): - return ln_prob + np.log(ratio) + if xp.all(self.evaluate_constraints(sample, xp=xp)): + return ln_prob + xp.log(ratio) else: return -np.inf else: - constrained_ln_prob = -np.inf * np.ones_like(ln_prob) - in_bounds = np.isfinite(ln_prob) + constrained_ln_prob = -np.inf * xp.ones_like(ln_prob) + in_bounds = xp.isfinite(ln_prob) subsample = {key: sample[key][in_bounds] for key in sample} - keep = np.log(np.array(self.evaluate_constraints(subsample), dtype=bool)) - constrained_ln_prob[in_bounds] = ln_prob[in_bounds] + keep + np.log(ratio) + keep = xp.log(self.evaluate_constraints(subsample, xp=xp)) + constrained_ln_prob = xpx.at(constrained_ln_prob, in_bounds).set( + ln_prob[in_bounds] + keep + 
xp.log(ratio) + ) return constrained_ln_prob @xp_wrap @@ -827,8 +835,7 @@ def prob(self, sample, *, xp=None, **kwargs): for key in sample ]) prob = xp.prod(res, **kwargs) - return prob - # return self.check_prob(sample, prob) + return self.check_prob(sample, prob, xp=xp) def ln_prob(self, sample, axis=None, normalized=True): """ diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 924a24c6e..238c0d791 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -327,7 +327,9 @@ def rescale(self, value, *, xp=None, **kwargs): raise ValueError("Array is the wrong shape") samp = self._rescale(samp, **kwargs) - return xp.squeeze(samp) + if samp.shape[0] == 1: + samp = xp.squeeze(samp, axis=0) + return samp def _rescale(self, samp, **kwargs): """ @@ -627,8 +629,8 @@ def _rescale(self, samp, *, xp=None, **kwargs): samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 # rotate and scale to the multivariate normal shape - samp = xp.asarray(self.mus[mode]) + self.sigmas[mode] * xp.einsum( - "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode] + samp = xp.asarray(self.mus[mode]) + xp.asarray(self.sigmas[mode]) * xp.einsum( + "ij,kj->ik", samp * self.sqeigvalues[mode], xp.asarray(self.eigvectors[mode]) ) return samp @@ -684,11 +686,11 @@ def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() z = (samp[j] - self.mus[i]) / self.sigmas[i] lnprob = xpx.at(lnprob, j).set( - xp.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - self.logprodsigmas[i]) + xp.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - xp.asarray(self.logprodsigmas[i])) ) # set out-of-bounds values to -inf - lnprob = xp.where(outbounds, -xp.inf, lnprob) + lnprob = xp.where(xp.asarray(outbounds), -np.inf, lnprob) return lnprob def __eq__(self, other): @@ -801,7 +803,11 @@ def rescale(self, val, *, xp=None, **kwargs): self.dist.rescale_parameters[self.name] = val if self.dist.filled_rescale(): - values = 
xp.asarray(list(self.dist.rescale_parameters.values())).T + # print(self.dist.rescale_parameters) + values = xp.stack([ + xp.asarray(val) for val in self.dist.rescale_parameters.values() + ]).T + # values = xp.asarray(list(self.dist.rescale_parameters.values())).T samples = self.dist.rescale(values, **kwargs) self.dist.reset_rescale() return samples @@ -878,7 +884,7 @@ def ln_prob(self, val, *, xp=None): "number of requested values." ) - lnp = self.dist.ln_prob(xp.asarray(values).T) + lnp = self.dist.ln_prob(xp.stack(values).T) # reset the requested parameters self.dist.reset_request() diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index 23aed86d3..ad4f118b8 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -97,9 +97,10 @@ def rescale(self, val, *, xp=None): val - self.spike_height * higher_indices, xp=xp ) - res = xp.select( - [lower_indices | higher_indices, intermediate_indices], - [slab_scaled, self.spike_location], + res = xp.where( + lower_indices | higher_indices, + slab_scaled, + xp.asarray(self.spike_location), ) return res diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index 01a92b9c8..496946c2b 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -180,20 +180,14 @@ def logtrapzexp(lnf, dx, *, xp=np): except TypeError: raise TypeError(f"Step size dx={dx} could not be converted to an array") - if dx.size == 1: - C = np.log(dx / 2.0) - else: - if dx.size != len(lnf) - 1: - raise ValueError( - "Step size array must have length one less than the function length" - ) - - lndx = xp.log(dx) - lnfdx1 = lnfdx1.copy() + lndx - lnfdx2 = lnfdx2.copy() + lndx - C = -xp.log(2.0) + if dx.ndim > 0 and len(dx) != len(lnf) - 1: + raise ValueError( + "Step size array must have length one less than the function length" + ) + lnfdx1 = lnfdx1 + xp.log(dx) + lnfdx2 = lnfdx2 + xp.log(dx) - return C + logsumexp(xp.asarray([logsumexp(lnfdx1), logsumexp(lnfdx2)])) + 
return logsumexp(xp.asarray([logsumexp(lnfdx1), logsumexp(lnfdx2)])) - np.log(2) class interp1d(_interp1d): diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index d24b16fe2..f4c9bc4e8 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -60,7 +60,7 @@ def default(self, obj): return encode_astropy_unit(obj) except ImportError: logger.debug("Cannot import astropy, cannot write cosmological priors") - if hasattr(obj, "__array_namespace__"): + if aac.is_array_api_obj(obj): return { "__array__": True, "__array_namespace__": aac.get_namespace(obj).__name__, @@ -445,7 +445,7 @@ def encode_for_hdf5(key, item): if item.dtype.kind == 'U': logger.debug(f'converting dtype {item.dtype} for hdf5') item = np.array(item, dtype='S') - elif hasattr(item, "__array_namespace__"): + elif aac.is_array_api_obj(item): # temporarily dump all arrays as numpy arrays, we should figure ou # how to properly deserialize them item = np.asarray(item) diff --git a/bilby/core/utils/series.py b/bilby/core/utils/series.py index c60362ab3..4fa20b51a 100644 --- a/bilby/core/utils/series.py +++ b/bilby/core/utils/series.py @@ -123,7 +123,7 @@ def create_frequency_series(sampling_frequency, duration): """ xp = array_module(sampling_frequency) _check_legal_sampling_frequency_and_duration(sampling_frequency, duration) - number_of_samples = int(xp.round(duration * sampling_frequency)) + number_of_samples = xp.round(duration * sampling_frequency) number_of_frequencies = int(xp.round(number_of_samples / 2) + 1) return xp.linspace(0, sampling_frequency / 2, num=number_of_frequencies) diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index cd0cf182d..9127edeb2 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -1744,13 +1744,13 @@ def _ln_prob(self, samp, lnprob, outbounds, *, xp=None): phi, dec = samp[0] theta = 0.5 * np.pi - dec pixel = self.hp.ang2pix(self.nside, theta, phi) - xpx.at(lnprob, i).set(xp.log(self.prob[pixel] / self.pixel_area)) + xpx.at(lnprob, 
i).set(xp.log(xp.asarray(self.prob[pixel] / self.pixel_area))) if self.distance: self.update_distance(pixel) lnprob = xpx.at(lnprob, i).set( - lnprob[i] + xp.log(self.distance_pdf(dist) * dist ** 2) + lnprob[i] + xp.log(xp.asarray(self.distance_pdf(dist) * dist ** 2)) ) - lnprob = xp.where(outbounds, -np.inf, lnprob) + lnprob = xp.where(xp.asarray(outbounds), -np.inf, lnprob) return lnprob def __eq__(self, other): diff --git a/requirements.txt b/requirements.txt index ead66363d..f1a91484d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -array_api_compat +# see https://github.com/data-apis/array-api-compat/pull/341 +array_api_compat>=1.13 array_api_extra dynesty>=2.0.1 emcee diff --git a/test/conftest.py b/test/conftest.py index fda4a4387..a9be96548 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -48,8 +48,20 @@ def _xp(request): os.environ["SCIPY_ARRAY_API"] = "1" jax.config.update("jax_enable_x64", True) xp = jax.numpy + case "torch": + import torch + # torch starts a lot of threads, so disable this on the first import + # to avoid segfaults + try: + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + torch.set_default_dtype(torch.float64) + except RuntimeError: + pass + xp = torch case _: try: + xp = importlib.import_module(backend) except ImportError: raise ValueError(f"Unknown backend for testing: {backend}") diff --git a/test/core/grid_test.py b/test/core/grid_test.py index 009ab2c15..781077f34 100644 --- a/test/core/grid_test.py +++ b/test/core/grid_test.py @@ -2,11 +2,11 @@ import shutil import os +import array_api_compat as aac import numpy as np import pytest import bilby -from bilby.compat.patches import multivariate_logpdf class MultiGaussian(bilby.Likelihood): @@ -17,15 +17,13 @@ def __init__(self, mean, cov, *, xp=np): self.cov = xp.asarray(cov) self.mean = xp.asarray(mean) self.sigma = xp.sqrt(xp.diag(self.cov)) - self.logpdf = multivariate_logpdf(xp=xp, mean=self.mean, cov=self.cov) @property def dim(self): 
return len(self.cov[0]) def log_likelihood(self, parameters): - x = self.xp.asarray([parameters["x{0}".format(i)] for i in range(self.dim)]) - return self.logpdf(x) + return -parameters["x0"]**2 / 2 - parameters["x1"]**2 / 2 - np.log(2 * np.pi) @pytest.mark.array_backend @@ -146,7 +144,9 @@ def test_max_marginalized_likelihood(self): self.assertEqual(1.0, self.grid.marginalize_likelihood(self.grid.parameter_names[1]).max()) def test_ln_evidence(self): - self.assertAlmostEqual(self.expected_ln_evidence, self.grid.ln_evidence, places=5) + ln_z = self.grid.ln_evidence + self.assertEqual(aac.get_namespace(ln_z), self.xp) + self.assertAlmostEqual(self.expected_ln_evidence, float(ln_z), places=5) def test_fail_grid_size(self): with self.assertRaises(TypeError): @@ -218,7 +218,7 @@ def test_grid_from_array(self): def test_save_and_load_from_filename(self): filename = os.path.join("outdir", "test_output.json") self.grid.save_to_file(filename=filename) - new_grid = bilby.core.grid.Grid.read(filename=filename) + new_grid = bilby.core.grid.Grid.read(filename=filename, xp=self.xp) self.assertListEqual(new_grid.parameter_names, self.grid.parameter_names) self.assertEqual(new_grid.n_dims, self.grid.n_dims) @@ -231,7 +231,7 @@ def test_save_and_load_from_filename(self): def test_save_and_load_from_outdir_label(self): self.grid.save_to_file(overwrite=True, outdir="outdir") - new_grid = bilby.core.grid.Grid.read(outdir="outdir", label="label") + new_grid = bilby.core.grid.Grid.read(outdir="outdir", label="label", xp=self.xp) self.assertListEqual(self.grid.parameter_names, new_grid.parameter_names) self.assertEqual(self.grid.n_dims, new_grid.n_dims) @@ -248,7 +248,7 @@ def test_save_and_load_from_outdir_label(self): def test_save_and_load_gzip(self): filename = os.path.join("outdir", "test_output.json.gz") self.grid.save_to_file(filename=filename) - new_grid = bilby.core.grid.Grid.read(filename=filename) + new_grid = bilby.core.grid.Grid.read(filename=filename, xp=self.xp) 
self.assertListEqual(self.grid.parameter_names, new_grid.parameter_names) self.assertEqual(self.grid.n_dims, new_grid.n_dims) diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index 368b7f1e5..418df8146 100644 --- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -308,8 +308,8 @@ def setUp(self): self.mu = 5 self.x = self.xp.linspace(0, 1, self.N) self.y = self.xp.asarray(np.random.poisson(self.mu, self.N)) - self.yfloat = self.y.copy() * 1.0 - self.yneg = self.y.copy() + self.yfloat = self.y * 1.0 + self.yneg = self.y * 1.0 self.yneg = xpx.at(self.yneg, 0).set(-1) def test_function(x, c): @@ -335,6 +335,8 @@ def tearDown(self): del self.poisson_likelihood def test_init_y_non_integer(self): + if aac.is_torch_namespace(self.xp): + pytest.skip("Torch tensor dtype does not have a 'kind' attribute") with self.assertRaises(ValueError): PoissonLikelihood(self.x, self.yfloat, self.function) @@ -352,12 +354,14 @@ def test_neg_rate_array(self): likelihood.log_likelihood(self.bad_parameters) def test_init_y(self): - self.assertTrue(self.xp.array_equal(self.y, self.poisson_likelihood.y)) + self.assertEqual(aac.get_namespace(self.y), aac.get_namespace(self.poisson_likelihood.y)) + np.testing.assert_array_equal(np.asarray(self.y), np.asarray(self.poisson_likelihood.y)) def test_set_y_to_array(self): new_y = self.xp.arange(0, 50, step=2) self.poisson_likelihood.y = new_y - self.assertTrue(self.xp.array_equal(new_y, self.poisson_likelihood.y)) + self.assertEqual(aac.get_namespace(new_y), aac.get_namespace(self.poisson_likelihood.y)) + np.testing.assert_array_equal(np.asarray(new_y), np.asarray(self.poisson_likelihood.y)) def test_set_y_to_positive_int(self): new_y = 5 @@ -418,7 +422,7 @@ def setUp(self): self.mu = 5 self.x = self.xp.linspace(0, 1, self.N) self.y = self.xp.asarray(np.random.exponential(self.mu, self.N)) - self.yneg = self.y.copy() + self.yneg = self.y * 1.0 self.yneg = xpx.at(self.yneg, 0).set(-1.0) def test_function(x, 
c): @@ -510,9 +514,9 @@ def setUp(self): self.sigma = [1, 2, 3] self.mean = [10, 11, 12] if self.xp != np: - self.cov = self.xp.asarray(self.cov) - self.sigma = self.xp.asarray(self.sigma) - self.mean = self.xp.asarray(self.mean) + self.cov = self.xp.asarray(self.cov, dtype=float) + self.sigma = self.xp.asarray(self.sigma, dtype=float) + self.mean = self.xp.asarray(self.mean, dtype=float) self.likelihood = AnalyticalMultidimensionalCovariantGaussian( mean=self.mean, cov=self.cov ) @@ -538,7 +542,7 @@ def test_dim(self): def test_log_likelihood(self): likelihood = AnalyticalMultidimensionalCovariantGaussian( - mean=self.xp.asarray([0]), cov=self.xp.asarray([1]) + mean=self.xp.asarray([0.0]), cov=self.xp.asarray([1.0]) ) logl = likelihood.log_likelihood(dict(x0=self.xp.asarray(0.0))) self.assertEqual( @@ -557,10 +561,10 @@ def setUp(self): self.mean_1 = [10, 11, 12] self.mean_2 = [20, 21, 22] if self.xp != np: - self.cov = self.xp.asarray(self.cov) - self.sigma = self.xp.asarray(self.sigma) - self.mean_1 = self.xp.asarray(self.mean_1) - self.mean_2 = self.xp.asarray(self.mean_2) + self.cov = self.xp.asarray(self.cov, dtype=float) + self.sigma = self.xp.asarray(self.sigma, dtype=float) + self.mean_1 = self.xp.asarray(self.mean_1, dtype=float) + self.mean_2 = self.xp.asarray(self.mean_2, dtype=float) self.likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian( mean_1=self.mean_1, mean_2=self.mean_2, cov=self.cov ) diff --git a/test/core/prior/analytical_test.py b/test/core/prior/analytical_test.py index ec4ec975b..09942ba07 100644 --- a/test/core/prior/analytical_test.py +++ b/test/core/prior/analytical_test.py @@ -9,6 +9,10 @@ @pytest.mark.array_backend @pytest.mark.usefixtures("xp_class") class TestDiscreteValuesPrior(unittest.TestCase): + def setUp(self): + if aac.is_torch_namespace(self.xp): + pytest.skip("DiscreteValues prior is unstable for torch backend") + def test_single_sample(self): values = [1.1, 1.2, 1.3] discrete_value_prior = 
bilby.core.prior.DiscreteValues(values) @@ -63,12 +67,9 @@ def test_array_probability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertTrue( - np.all( - discrete_value_prior.prob(self.xp.asarray([1.1, 2.2, 2.2, 300.0, 200.0])) - == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) - ) - ) + probs = discrete_value_prior.prob(self.xp.asarray([1.1, 2.2, 2.2, 300.0, 200.0])) + self.assertEqual(aac.get_namespace(probs), self.xp) + np.testing.assert_array_equal(np.asarray(probs), np.array([1 / N] * 4 + [0])) def test_single_lnprobability(self): N = 3 @@ -87,12 +88,9 @@ def test_array_lnprobability(self): N = 3 values = [1.1, 2.2, 300.0] discrete_value_prior = bilby.core.prior.DiscreteValues(values) - self.assertTrue( - np.all( - discrete_value_prior.ln_prob(self.xp.asarray([1.1, 2.2, 2.2, 300, 150])) - == np.array([-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]) - ) - ) + ln_probs = discrete_value_prior.ln_prob(self.xp.asarray([1.1, 2.2, 2.2, 300, 150])) + self.assertEqual(aac.get_namespace(ln_probs), self.xp) + np.testing.assert_array_equal(np.asarray(ln_probs), np.array([-np.log(N)] * 4 + [-np.inf])) @pytest.mark.array_backend diff --git a/test/core/prior/base_test.py b/test/core/prior/base_test.py index 84999d42c..469c53ece 100644 --- a/test/core/prior/base_test.py +++ b/test/core/prior/base_test.py @@ -160,6 +160,8 @@ def test_prob_integrate_to_one(self): n_samples = 1000000 samples = self.priors.sample_subset(keys=keys, size=n_samples, xp=self.xp) prob = self.priors.prob(samples, axis=0) + self.assertEqual(aac.get_namespace(prob), self.xp) + prob = np.asarray(prob) dm1 = self.priors["a"].maximum - self.priors["a"].minimum dm2 = self.priors["b"].maximum - self.priors["b"].minimum prior_volume = (dm1 * dm2) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index c24ae1118..a3165adce 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -254,6 
+254,11 @@ def condition_func(reference_params, test_param): dist=hp_3d_dist, name="testdistance", unit="unit" ), ] + if aac.is_torch_namespace(self.xp): + self.priors = [ + p for p in self.priors + if not isinstance(p, bilby.core.prior.Interped) + ] def tearDown(self): del self.priors @@ -268,29 +273,35 @@ def test_minimum_rescaling(self): # the edge of the prior is extremely suppressed for these priors # and so the rescale function doesn't quite return the lower bound continue + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue elif bilby.core.prior.JointPrior in prior.__class__.__mro__: minimum_sample = prior.rescale(self.xp.asarray(0)) if prior.dist.filled_rescale(): - self.assertAlmostEqual(minimum_sample[0], prior.minimum) - self.assertAlmostEqual(minimum_sample[1], prior.minimum) + self.assertAlmostEqual(np.asarray(minimum_sample[0]), prior.minimum) + self.assertAlmostEqual(np.asarray(minimum_sample[1]), prior.minimum) else: minimum_sample = prior.rescale(self.xp.asarray(0)) - self.assertAlmostEqual(minimum_sample, prior.minimum) + self.assertAlmostEqual(np.asarray(minimum_sample), prior.minimum) def test_maximum_rescaling(self): """Test the the rescaling works as expected.""" for prior in self.priors: + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue if bilby.core.prior.JointPrior in prior.__class__.__mro__: maximum_sample = prior.rescale(self.xp.asarray(0)) if prior.dist.filled_rescale(): - self.assertAlmostEqual(maximum_sample[0], prior.maximum) - self.assertAlmostEqual(maximum_sample[1], prior.maximum) + self.assertAlmostEqual(np.asarray(maximum_sample[0]), prior.maximum) + self.assertAlmostEqual(np.asarray(maximum_sample[1]), prior.maximum) elif isinstance(prior, bilby.gw.prior.AlignedSpin): maximum_sample = prior.rescale(self.xp.asarray(1)) - 
self.assertGreater(maximum_sample, 0.997) + self.assertGreater(np.asarray(maximum_sample), 0.997) else: maximum_sample = prior.rescale(self.xp.asarray(1)) - self.assertAlmostEqual(maximum_sample, prior.maximum) + self.assertAlmostEqual(np.asarray(maximum_sample), prior.maximum) def test_many_sample_rescaling(self): """Test the the rescaling works as expected.""" @@ -298,6 +309,9 @@ def test_many_sample_rescaling(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue many_samples = prior.rescale(self.xp.asarray(np.random.uniform(0, 1, 1000))) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_rescale(): @@ -383,12 +397,12 @@ def test_prob_and_ln_prob(self): # the prob and ln_prob functions, it must be ignored in this test. lnprob = prior.ln_prob(sample) prob = prior.prob(sample) - # lower precision for jax running tests with float32 - self.assertAlmostEqual( - self.xp.log(prob), lnprob, 6 - ) self.assertEqual(aac.get_namespace(lnprob), self.xp) self.assertEqual(aac.get_namespace(prob), self.xp) + # lower precision for jax running tests with float32 + lnprob = np.asarray(lnprob) + prob = np.asarray(prob) + self.assertAlmostEqual(np.log(prob), lnprob, 6) def test_many_prob_and_many_ln_prob(self): for prior in self.priors: @@ -431,6 +445,9 @@ def test_cdf_is_inverse_of_rescaling(self): def test_cdf_one_above_domain(self): for prior in self.priors: + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue if prior.maximum != np.inf: outside_domain = self.xp.linspace( prior.maximum + 1, prior.maximum + 1e4, 1000 @@ -442,6 +459,9 @@ def test_cdf_zero_below_domain(self): if isinstance(prior, 
bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue + if isinstance(prior, bilby.gw.prior.HealPixPrior) and aac.is_torch_namespace(self.xp): + # HealPix rescaling requires interpolation + continue if ( bilby.core.prior.JointPrior in prior.__class__.__mro__ and prior.maximum == np.inf @@ -604,7 +624,12 @@ def test_probability_in_domain(self): else: maximum = prior.maximum domain = self.xp.linspace(minimum, maximum, 1000) - self.assertTrue(all(prior.prob(domain) >= 0)) + print(prior) + prob = prior.prob(domain) + print(min(prob)) + self.assertEqual(aac.get_namespace(prob), self.xp) + prob = np.asarray(prob) + self.assertTrue(all(prob >= 0)) def test_probability_surrounding_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" @@ -670,7 +695,9 @@ def test_normalized(self): domain = np.linspace(prior.minimum, prior.maximum, 10000) elif isinstance(prior, bilby.core.prior.WeightedDiscreteValues): domain = prior.values - self.assertTrue(np.sum(prior.prob(self.xp.asarray(domain))) == 1) + probs = prior.prob(self.xp.asarray(domain)) + self.assertEqual(aac.get_namespace(probs), self.xp) + self.assertTrue(np.sum(np.asarray(probs)) == 1) continue else: domain = np.linspace(prior.minimum, prior.maximum, 1000) @@ -831,6 +858,7 @@ def test_repr(self): repr_prior_string = repr_prior_string.replace( "HealPixMapPriorDist", "bilby.gw.prior.HealPixMapPriorDist" ) + prior.dist.rescale_parameters = {key: None for key in prior.dist.names} elif isinstance(prior, bilby.gw.prior.UniformComovingVolume): repr_prior_string = "bilby.gw.prior." 
+ repr(prior) elif "Conditional" in prior.__class__.__name__: @@ -892,24 +920,6 @@ def test_set_minimum_setting(self): prior.minimum = (prior.maximum + prior.minimum) / 2 self.assertTrue(min(prior.sample(10000, xp=self.xp)) > prior.minimum) - # def test_jax_methods(self): - # import jax - - # points = jax.numpy.linspace(1e-3, 1 - 1e-3, 10) - # for prior in self.priors: - # if bilby.core.prior.JointPrior in prior.__class__.__mro__: - # continue - # scaled = prior.rescale(points) - # assert isinstance(scaled, jax.Array) - # if isinstance(prior, bilby.core.prior.DeltaFunction): - # continue - # probs = prior.prob(scaled) - # assert min(probs) > 0 - # assert max(abs(jax.numpy.log(probs) - prior.ln_prob(scaled))) < 1e-6 - # if isinstance(prior, bilby.core.prior.WeightedDiscreteValues): - # continue - # assert max(abs(prior.cdf(scaled) - points)) < 1e-6 - if __name__ == "__main__": unittest.main() diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index 1ec76ab71..7c5716b8a 100644 --- a/test/core/prior/slabspike_test.py +++ b/test/core/prior/slabspike_test.py @@ -68,10 +68,10 @@ def test_set_spike_height_domain_edge(self): class TestSlabSpikeClasses(unittest.TestCase): def setUp(self): - self.minimum = 0.4 - self.maximum = 2.4 + self.minimum = self.xp.asarray(0.4) + self.maximum = self.xp.asarray(2.4) self.spike_loc = self.xp.asarray(1.5) - self.spike_height = 0.3 + self.spike_height = self.xp.asarray(0.3) self.slabs = [ Uniform(minimum=self.minimum, maximum=self.maximum), @@ -80,8 +80,8 @@ def setUp(self): TruncatedGaussian(minimum=self.minimum, maximum=self.maximum, mu=0, sigma=1), Beta(minimum=self.minimum, maximum=self.maximum, alpha=1, beta=1), Gaussian(mu=0, sigma=1), - Cosine(), - Sine(), + Cosine(minimum=self.xp.asarray(-np.pi / 2), maximum=self.xp.asarray(np.pi / 2)), + Sine(minimum=self.xp.asarray(0), maximum=self.xp.asarray(np.pi)), HalfGaussian(sigma=1), LogNormal(mu=1, sigma=2), Exponential(mu=2), @@ -189,8 +189,11 @@ def 
test_rescale_no_spike(self): vals = self.xp.linspace(0, 1, 1000) expected = slab.rescale(vals) actual = slab_spike.rescale(vals) - self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) self.assertEqual(aac.get_namespace(actual), self.xp) + self.assertEqual(aac.get_namespace(expected), self.xp) + actual = np.asarray(actual) + expected = np.asarray(expected) + self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) def test_rescale_below_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): diff --git a/test/core/series_test.py b/test/core/series_test.py index aec2ff42f..c2b8dccdb 100644 --- a/test/core/series_test.py +++ b/test/core/series_test.py @@ -1,5 +1,6 @@ import unittest +import array_api_compat as aac import numpy as np import pytest @@ -47,10 +48,10 @@ def test_start_time_from_init(self): self.assertEqual(self.start_time, self.series.start_time) def test_frequency_array_type(self): - self.assertIsInstance(self.series.frequency_array, self.xp.ndarray) + self.assertEqual(aac.get_namespace(self.series.frequency_array), self.xp) def test_time_array_type(self): - self.assertIsInstance(self.series.time_array, self.xp.ndarray) + self.assertEqual(aac.get_namespace(self.series.time_array), self.xp) def test_frequency_array_from_init(self): expected = create_frequency_series( @@ -94,10 +95,10 @@ def test_time_array_setter(self): self.series.time_array = new_time_array self.assertTrue(np.array_equal(new_time_array, self.series.time_array)) self.assertAlmostEqual( - new_sampling_frequency, self.series.sampling_frequency, places=1 + np.asarray(new_sampling_frequency), np.asarray(self.series.sampling_frequency), places=1 ) - self.assertAlmostEqual(new_duration, self.series.duration, places=1) - self.assertAlmostEqual(new_start_time, self.series.start_time, places=1) + self.assertAlmostEqual(np.asarray(new_duration), np.asarray(self.series.duration), places=1) + self.assertAlmostEqual(np.asarray(new_start_time), np.asarray(self.series.start_time), 
places=1) def test_time_array_without_sampling_frequency(self): self.series.sampling_frequency = None diff --git a/test/core/utils_test.py b/test/core/utils_test.py index f766f2c74..d8a78beee 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -72,6 +72,9 @@ def test_nfft_sine_function(self): time_domain_strain, self.sampling_frequency ) frequency_at_peak = frequencies[xp.argmax(abs(frequency_domain_strain))] + self.assertEqual(aac.get_namespace(frequency_at_peak), xp) + frequency_at_peak = np.asarray(frequency_at_peak) + injected_frequency = np.asarray(injected_frequency) self.assertAlmostEqual(injected_frequency, frequency_at_peak, places=1) def test_nfft_infft(self): @@ -344,6 +347,8 @@ def plot(): @pytest.mark.usefixtures("xp_class") class TestUnsortedInterp2d(unittest.TestCase): def setUp(self): + if aac.is_torch_namespace(self.xp): + pytest.skip("Skipping Interp2d tests for torch backend") self.xx = np.linspace(0, 1, 10) self.yy = np.linspace(0, 1, 10) self.zz = np.random.random((10, 10)) From 38cc5f66bcb2cffccbfde566cbe7c6e900c1f325 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 08:23:56 -0500 Subject: [PATCH 093/110] FMT: run precommits --- bilby/core/grid.py | 2 +- bilby/core/prior/analytical.py | 4 +++- bilby/core/prior/dict.py | 2 +- bilby/core/prior/slabspike.py | 4 ---- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/bilby/core/grid.py b/bilby/core/grid.py index 4d1a7501c..12dfa2bf6 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -218,7 +218,7 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): elif aac.is_torch_namespace(xp): # https://discuss.pytorch.org/t/apply-a-function-along-an-axis/130440 out = xp.stack([ - logtrapzexp(x_i, dx=dx, xp=xp) for x_i in xp.unbind(log_array, dim=axis) + logtrapzexp(x_i, dx=dx, xp=xp) for x_i in xp.unbind(log_array, dim=axis) ], dim=min(axis, log_array.ndim - 2)) else: out = xp.apply_along_axis( diff --git a/bilby/core/prior/analytical.py 
b/bilby/core/prior/analytical.py index d0a9bc22e..02e65bf3a 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -140,7 +140,9 @@ def prob(self, val, *, xp=None): float: Prior probability of val """ if self.alpha == -1: - return xp.nan_to_num(1 / val / xp.log(xp.asarray(self.maximum / self.minimum))) * self.is_in_prior_range(val) + return xp.nan_to_num( + 1 / val / xp.log(xp.asarray(self.maximum / self.minimum)) + ) * self.is_in_prior_range(val) else: return xp.nan_to_num(val ** self.alpha * (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index d4d627299..dd8e586e6 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -614,7 +614,7 @@ def check_ln_prob(self, sample, ln_prob, normalized=True, *, xp=None): subsample = {key: sample[key][in_bounds] for key in sample} keep = xp.log(self.evaluate_constraints(subsample, xp=xp)) constrained_ln_prob = xpx.at(constrained_ln_prob, in_bounds).set( - ln_prob[in_bounds] + keep + xp.log(ratio) + ln_prob[in_bounds] + keep + xp.log(ratio) ) return constrained_ln_prob diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index ad4f118b8..2ac310f55 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -87,10 +87,6 @@ def rescale(self, val, *, xp=None): array_like: Associated prior value with input value. 
""" lower_indices = val < self.inverse_cdf_below_spike - intermediate_indices = ( - (self.inverse_cdf_below_spike <= val) - * (val < (self.inverse_cdf_below_spike + self.spike_height)) - ) higher_indices = val >= (self.inverse_cdf_below_spike + self.spike_height) slab_scaled = self._contracted_rescale( From 65029defda9a94b414def2d2ce04f168a67dff42 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 10:52:11 -0500 Subject: [PATCH 094/110] Make torch fully tested --- .github/workflows/unit-tests.yml | 2 +- bilby/gw/compat/__init__.py | 5 +++ bilby/gw/compat/torch.py | 19 ++++++++++ bilby/gw/conversion.py | 4 +-- bilby/gw/detector/calibration.py | 12 +++---- bilby/gw/detector/interferometer.py | 3 +- bilby/gw/detector/psd.py | 6 ++-- bilby/gw/geometry.py | 2 +- bilby/gw/likelihood/base.py | 6 ++-- bilby/gw/likelihood/multiband.py | 20 ++++++----- bilby/gw/likelihood/relative.py | 6 ++-- bilby/gw/likelihood/roq.py | 4 +-- bilby/gw/utils.py | 10 +++--- docs/array_api.rst | 18 ++++++++++ test/gw/conversion_test.py | 31 ++++++++-------- test/gw/detector/geometry_test.py | 12 +++---- test/gw/likelihood_test.py | 55 ++++++++++++++++------------- test/gw/prior_test.py | 6 ++-- test/gw/utils_test.py | 10 ++++-- test/gw/waveform_generator_test.py | 4 +-- 20 files changed, 149 insertions(+), 86 deletions(-) create mode 100644 bilby/gw/compat/torch.py diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 6bce425e2..1a6275fd3 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -66,7 +66,7 @@ jobs: pytest --array-backend jax --durations 10 - name: Run torch-backend unit tests run: | - pytest --array-backend torch --durations 10 test/core + pytest --array-backend torch --durations 10 - name: Run sampler tests run: | pytest test/integration/sampler_run_test.py --durations 10 -v diff --git a/bilby/gw/compat/__init__.py b/bilby/gw/compat/__init__.py index 16eea2d00..36f2566c4 100644 --- 
a/bilby/gw/compat/__init__.py +++ b/bilby/gw/compat/__init__.py @@ -6,5 +6,10 @@ try: from .cython import gps_time_to_utc +except ModuleNotFoundError: + pass + +try: + from .torch import n_leap_seconds except ModuleNotFoundError: pass \ No newline at end of file diff --git a/bilby/gw/compat/torch.py b/bilby/gw/compat/torch.py new file mode 100644 index 000000000..b3958f347 --- /dev/null +++ b/bilby/gw/compat/torch.py @@ -0,0 +1,19 @@ +import torch +from plum import dispatch + +from ..time import ( + LEAP_SECONDS as _LEAP_SECONDS, + n_leap_seconds as _n_leap_seconds, +) + +__all__ = ["n_leap_seconds"] + +LEAP_SECONDS = torch.tensor(_LEAP_SECONDS) + + +@dispatch +def n_leap_seconds(date: torch.Tensor) -> torch.Tensor: + """ + Find the number of leap seconds required for the specified date. + """ + return _n_leap_seconds(date, LEAP_SECONDS) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index ad751c4fa..cc1b3e493 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -645,7 +645,7 @@ def spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3) model_space_mean = xp.asarray([0.89421, 0.33878, -0.07894, 0.00393]) model_space_standard_deviation = xp.asarray([0.35700, 0.25769, 0.05452, 0.00312]) converted_gamma_parameters = \ - model_space_mean + model_space_standard_deviation * xp.dot(transformation_matrix, sampled_pca_gammas) + model_space_mean + model_space_standard_deviation * (transformation_matrix @ sampled_pca_gammas) return converted_gamma_parameters @@ -1046,7 +1046,7 @@ def component_masses_to_symmetric_mass_ratio(mass_1, mass_2): Symmetric mass ratio of the binary """ xp = array_module(mass_1) - return xp.minimum((mass_1 * mass_2) / (mass_1 + mass_2) ** 2, 1 / 4) + return xp.minimum((mass_1 * mass_2) / (mass_1 + mass_2) ** 2, xp.asarray(0.25)) def component_masses_to_mass_ratio(mass_1, mass_2): diff --git a/bilby/gw/detector/calibration.py b/bilby/gw/detector/calibration.py index 883275016..a4cebffe6 100644 --- 
a/bilby/gw/detector/calibration.py +++ b/bilby/gw/detector/calibration.py @@ -48,7 +48,7 @@ from array_api_compat import is_jax_namespace from scipy.interpolate import interp1d -from ...compat.utils import array_module +from ...compat.utils import array_module, xp_wrap from ...core.utils.log import logger from ...core.prior.dict import PriorDict from ..prior import CalibrationPriorDict @@ -345,7 +345,8 @@ def _evaluate_spline(self, kind, a, b, c, d, previous_nodes): + d * spline_coefficients[next_nodes] ) - def get_calibration_factor(self, frequency_array, **params): + @xp_wrap + def get_calibration_factor(self, frequency_array, *, xp=np, **params): """Apply calibration model Parameters @@ -363,11 +364,11 @@ def get_calibration_factor(self, frequency_array, **params): calibration_factor : array-like The factor to multiply the strain by. """ - log10f_per_deltalog10f = np.nan_to_num( - np.log10(frequency_array) - self.log_spline_points[0], + log10f_per_deltalog10f = xp.nan_to_num( + xp.log10(frequency_array) - xp.asarray(self.log_spline_points[0]), neginf=0.0, ) / self.delta_log_spline_points - previous_nodes = np.clip(np.floor(log10f_per_deltalog10f).astype(int), a_min=0, a_max=self.n_points - 2) + previous_nodes = xp.clip(xp.astype(log10f_per_deltalog10f, int), min=0, max=self.n_points - 2) b = log10f_per_deltalog10f - previous_nodes a = 1 - b c = (a**3 - a) / 6 @@ -378,7 +379,6 @@ def get_calibration_factor(self, frequency_array, **params): delta_amplitude = self._evaluate_spline("amplitude", a, b, c, d, previous_nodes) delta_phase = self._evaluate_spline("phase", a, b, c, d, previous_nodes) calibration_factor = (1 + delta_amplitude) * (2 + 1j * delta_phase) / (2 - 1j * delta_phase) - xp = aac.get_namespace(calibration_factor) return xp.nan_to_num(calibration_factor) diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 6f5857132..bab36f05a 100644 --- a/bilby/gw/detector/interferometer.py +++ 
b/bilby/gw/detector/interferometer.py @@ -313,13 +313,14 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= used to set the time at which the antenna response is evaluated, otherwise the provided :code:`Parameters["geocent_time"]` is used. """ + xp = array_module(waveform_polarizations) if frequencies is None: # frequencies = self.frequency_array[self.frequency_mask] frequencies = self.frequency_array mask = self.frequency_mask else: - xp = array_module(frequencies) mask = xp.ones(len(frequencies), dtype=bool) + frequencies = xp.asarray(frequencies) if self.reference_time is None: antenna_time = parameters["geocent_time"] diff --git a/bilby/gw/detector/psd.py b/bilby/gw/detector/psd.py index a3948f966..e3fe7091a 100644 --- a/bilby/gw/detector/psd.py +++ b/bilby/gw/detector/psd.py @@ -3,6 +3,7 @@ import numpy as np from scipy.interpolate import interp1d +from ...compat.utils import xp_wrap from ...core import utils from ...core.utils import logger from .strain_data import InterferometerStrainData @@ -341,7 +342,8 @@ def __import_power_spectral_density(self): """ Automagically load a power spectral density curve """ self.frequency_array, self.psd_array = np.genfromtxt(self.psd_file).T - def get_noise_realisation(self, sampling_frequency, duration): + @xp_wrap + def get_noise_realisation(self, sampling_frequency, duration, *, xp=None): """ Generate frequency Gaussian noise scaled to the power spectral density. 
@@ -363,4 +365,4 @@ def get_noise_realisation(self, sampling_frequency, duration): frequency_domain_strain = self.__power_spectral_density_interpolated(frequencies) ** 0.5 * white_noise out_of_bounds = (frequencies < min(self.frequency_array)) | (frequencies > max(self.frequency_array)) frequency_domain_strain[out_of_bounds] = 0 * (1 + 1j) - return frequency_domain_strain, frequencies + return xp.asarray(frequency_domain_strain), xp.asarray(frequencies) diff --git a/bilby/gw/geometry.py b/bilby/gw/geometry.py index 54d2f3a1d..68321d4b4 100644 --- a/bilby/gw/geometry.py +++ b/bilby/gw/geometry.py @@ -181,7 +181,7 @@ def time_delay_from_geocenter(detector1, ra, dec, time): def zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x): """""" xp = array_module(delta_x) - omega_prime = xp.asarray( + omega_prime = xp.stack( [ xp.sin(zenith) * xp.cos(azimuth), xp.sin(zenith) * xp.sin(azimuth), diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index 4c288d9c4..bc8296915 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -304,7 +304,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr elif self.time_marginalization and self.calibration_marginalization: d_inner_h_integrand = np.tile( - interferometer.frequency_domain_strain.conjugate() * signal / + interferometer.frequency_domain_strain.conj() * signal / interferometer.power_spectral_density_array, (self.number_of_response_curves, 1)).T d_inner_h_integrand[_mask] *= self.calibration_draws[interferometer.name].T @@ -324,14 +324,14 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr elif self.time_marginalization and not self.calibration_marginalization: d_inner_h_array = normalization * np.fft.fft( signal[0:-1] - * interferometer.frequency_domain_strain.conjugate()[0:-1] + * interferometer.frequency_domain_strain.conj()[0:-1] / interferometer.power_spectral_density_array[0:-1] ) elif self.calibration_marginalization 
and ('recalib_index' not in parameters): d_inner_h_integrand = ( normalization * - interferometer.frequency_domain_strain.conjugate() * signal + interferometer.frequency_domain_strain.conj() * signal / interferometer.power_spectral_density_array ) d_inner_h_array = np.dot(d_inner_h_integrand[_mask], self.calibration_draws[interferometer.name].T) diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index aeec61387..69958a816 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -535,7 +535,7 @@ def _setup_linear_coefficients(self): logger.info("Pre-computing linear coefficients for {}".format(ifo.name)) fddata = np.zeros(N // 2 + 1, dtype=complex) fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask[:len(fddata)]] += \ - ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] + np.asarray(ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask]) for b in range(self.number_of_bands): Ks, Ke = self.Ks_Ke[b] windows = self._get_window_sequence(1. / self.durations[b], Ks, Ke - Ks + 1, b) @@ -757,14 +757,14 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr modes, parameters, frequencies=self.banded_frequency_points ) - d_inner_h = (strain @ self.linear_coeffs[interferometer.name]).conjugate() + d_inner_h = (strain @ self.linear_coeffs[interferometer.name]).conj() xp = array_module(strain) if self.linear_interpolation: optimal_snr_squared = xp.vdot( xp.abs(strain)**2, - self.quadratic_coeffs[interferometer.name] + xp.asarray(self.quadratic_coeffs[interferometer.name]) ) else: optimal_snr_squared = 0. @@ -775,16 +775,20 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr if b == 0: optimal_snr_squared += (4. 
/ self.interferometers.duration) * xp.vdot( xp.abs(strain[start_idx:end_idx + 1])**2, - interferometer.frequency_mask[Ks:Ke + 1] * self.windows[start_idx:end_idx + 1] + interferometer.frequency_mask[Ks:Ke + 1] * xp.asarray(self.windows[start_idx:end_idx + 1]) / interferometer.power_spectral_density_array[Ks:Ke + 1]) else: self.wths[interferometer.name][b][Ks:Ke + 1] = ( - self.square_root_windows[start_idx:end_idx + 1] * strain[start_idx:end_idx + 1] + xp.asarray(self.square_root_windows[start_idx:end_idx + 1]) + * strain[start_idx:end_idx + 1] ) - self.hbcs[interferometer.name][b][-Mb:] = xp.fft.irfft(self.wths[interferometer.name][b]) - thbc = xp.fft.rfft(self.hbcs[interferometer.name][b]) + self.hbcs[interferometer.name][b][-Mb:] = xp.fft.irfft( + xp.asarray(self.wths[interferometer.name][b]) + ) + thbc = xp.fft.rfft(xp.asarray(self.hbcs[interferometer.name][b])) + print(self.Ibcs[interferometer.name][b]) optimal_snr_squared += (4. / self.Tbhats[b]) * xp.vdot( - xp.abs(thbc)**2, self.Ibcs[interferometer.name][b]) + xp.abs(thbc)**2, xp.asarray(self.Ibcs[interferometer.name][b].real)) complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) diff --git a/bilby/gw/likelihood/relative.py b/bilby/gw/likelihood/relative.py index d40015219..5a7b2a539 100644 --- a/bilby/gw/likelihood/relative.py +++ b/bilby/gw/likelihood/relative.py @@ -398,8 +398,8 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr parameters=parameters, ) a0, a1, b0, b1 = self.summary_data[interferometer.name] - d_inner_h = (a0 * r0.conjugate() + a1 * r1.conjugate()).sum() - h_inner_h = (b0 * abs(r0) ** 2 + 2 * b1 * (r0 * r1.conjugate()).real).sum() + d_inner_h = (a0 * r0.conj() + a1 * r1.conj()).sum() + h_inner_h = (b0 * abs(r0) ** 2 + 2 * b1 * (r0 * r1.conj()).real).sum() optimal_snr_squared = h_inner_h complex_matched_filter_snr = d_inner_h / (optimal_snr_squared ** 0.5) @@ -412,7 +412,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, 
return_array=Tr ) d_inner_h_array = 4 / self.waveform_generator.duration * xp.fft.fft( full_waveform[0:-1] - * interferometer.frequency_domain_strain.conjugate()[0:-1] + * interferometer.frequency_domain_strain.conj()[0:-1] / interferometer.power_spectral_density_array[0:-1]) else: diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 9e9bdb28e..b5d84bafa 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -491,7 +491,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr indices = xp.clip(indices, 0, len(self.weights['time_samples']) - 1) d_inner_h_tc_array = xp.einsum( 'i,ji->j', - xp.conjugate(h_linear), + xp.conj(h_linear), xp.asarray( self.weights[interferometer.name + '_linear'][self.basis_number_linear] )[indices], @@ -601,7 +601,7 @@ def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): # Get the nearest 5 samples of d_inner_h. Calculate only the required d_inner_h values if the time # spacing is larger than 5 times the ROQ time spacing. weights_linear = self.weights[ifo_name + '_linear'][self.basis_number_linear] - h_linear_conj = np.conjugate(h_linear) + h_linear_conj = np.conj(h_linear) if (times[1] - times[0]) / roq_time_space > 5: d_inner_h_m2 = np.dot(weights_linear[closest_idxs - 2], h_linear_conj) d_inner_h_m1 = np.dot(weights_linear[closest_idxs - 1], h_linear_conj) diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index 13879b37d..bbec47fa3 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -108,11 +108,11 @@ def inner_product(aa, bb, frequency, PSD): psd_interp = PSD.power_spectral_density_interpolated(frequency) # calculate the inner product - integrand = np.conj(aa) * bb / psd_interp + integrand = (aa.conj() * bb / psd_interp).real df = frequency[1] - frequency[0] - integral = np.sum(integrand) * df - return 4. * np.real(integral) + integral = integrand.sum() * df + return 4. 
* integral def noise_weighted_inner_product(aa, bb, power_spectral_density, duration): @@ -134,7 +134,7 @@ ======= Noise-weighted inner product. """ - integrand = aa.conjugate() * bb / power_spectral_density + integrand = aa.conj() * bb / power_spectral_density return 4 / duration * integrand.sum() @@ -223,7 +223,7 @@ """ low_index = int(lower_cut_off / delta_frequency) up_index = int(upper_cut_off / delta_frequency) - integrand = signal_a.conjugate() * signal_b + integrand = signal_a.conj() * signal_b integrand = integrand[low_index:up_index] / power_spectral_density[low_index:up_index] integral = (4 * delta_frequency * integrand) / norm_a / norm_b return sum(integral).real diff --git a/docs/array_api.rst b/docs/array_api.rst index 6ce77c28a..8ce1cc043 100644 --- a/docs/array_api.rst +++ b/docs/array_api.rst @@ -27,6 +27,9 @@ Bilby is currently tested with the following array backends: - **NumPy** (default): Standard CPU-based computations - **JAX**: GPU/TPU acceleration and automatic differentiation +- **PyTorch**: GPU acceleration and deep learning integration. + :code:`PyTorch` support is not complete, for example, functionality + requiring interpolation is not available. While :code:`Bilby` should be compatible with other Array API compliant libraries, these are not currently tested or officially supported. @@ -213,6 +216,12 @@ and the analysis is then performed using the JIT-compiled likelihood. However, there is currently a performance issue with the distance marginalized likelihood using the :code:`JAX` backend. +.. warning:: + + Some array backends (notably :code:`torch`) are more picky than others about data types. + For maximal consistency, try to pass zero-dimensional arrays rather than :code:`Python` + scalars, e.g., :code:`torch.asarray(1.0)` instead of :code:`1.0`.
+ Performance Considerations -------------------------- @@ -314,6 +323,15 @@ The ``@xp_wrap`` decorator: - Falls back to NumPy when the input is a standard Python type or NumPy array - Handles the conversion seamlessly so users don't need to specify ``xp`` +Missing functionality +--------------------- + +The most significant missing functionality is the lack of a consistent random number generation +interface across different array backends. +Currently, all random calls use :code:`numpy.random` with the seed specified as described in :doc:`rng`. +This means that functionality like prior sampling and generating noise realizations in gravitational-wave +detectors will not be :code:`JIT`-compatible. + For Bilby Developers ===================== diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index d0ce869b7..8998924c3 100644 --- a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -86,28 +86,28 @@ def test_chirp_mass_and_primary_mass_to_mass_ratio(self): mass_ratio = conversion.chirp_mass_and_primary_mass_to_mass_ratio( self.chirp_mass, self.mass_1 ) - self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertAlmostEqual(float(self.mass_ratio), float(mass_ratio)) self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_symmetric_mass_ratio_to_mass_ratio(self): mass_ratio = conversion.symmetric_mass_ratio_to_mass_ratio( self.symmetric_mass_ratio ) - self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertAlmostEqual(float(self.mass_ratio), float(mass_ratio)) self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_chirp_mass_and_total_mass_to_symmetric_mass_ratio(self): symmetric_mass_ratio = conversion.chirp_mass_and_total_mass_to_symmetric_mass_ratio( self.chirp_mass, self.total_mass ) - self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) + self.assertAlmostEqual(float(self.symmetric_mass_ratio), float(symmetric_mass_ratio)) self.assertEqual(aac.get_namespace(symmetric_mass_ratio), 
self.xp) def test_chirp_mass_and_mass_ratio_to_total_mass(self): total_mass = conversion.chirp_mass_and_mass_ratio_to_total_mass( self.chirp_mass, self.mass_ratio ) - self.assertAlmostEqual(self.total_mass, total_mass) + self.assertAlmostEqual(float(self.total_mass), float(total_mass)) self.assertEqual(aac.get_namespace(total_mass), self.xp) def test_chirp_mass_and_mass_ratio_to_component_masses(self): @@ -138,7 +138,7 @@ def test_component_masses_to_symmetric_mass_ratio(self): def test_component_masses_to_mass_ratio(self): mass_ratio = conversion.component_masses_to_mass_ratio(self.mass_1, self.mass_2) - self.assertAlmostEqual(self.mass_ratio, mass_ratio) + self.assertAlmostEqual(float(self.mass_ratio), float(mass_ratio)) self.assertEqual(aac.get_namespace(mass_ratio), self.xp) def test_mass_1_and_chirp_mass_to_mass_ratio(self): @@ -656,8 +656,8 @@ def helper_generation_from_keys(self, keys, expected_values, source=False): self.assertTrue("mass_2" in local_test_vars_with_component_masses.keys()) for key in local_test_vars_with_component_masses.keys(): self.assertAlmostEqual( - local_test_vars_with_component_masses[key], - self.expected_values[key]) + np.asarray(local_test_vars_with_component_masses[key]), + np.asarray(self.expected_values[key])) # Test the function more generally local_all_mass_parameters = \ @@ -687,7 +687,10 @@ def helper_generation_from_keys(self, keys, expected_values, source=False): ) ) for key in local_all_mass_parameters.keys(): - self.assertAlmostEqual(expected_values[key], local_all_mass_parameters[key]) + self.assertAlmostEqual( + np.asarray(expected_values[key]), + np.asarray(local_all_mass_parameters[key]), + ) self.assertEqual( aac.get_namespace(local_all_mass_parameters[key]), self.xp, @@ -910,10 +913,10 @@ def test_spectral_pca_to_spectral(self): self.spectral_pca_gamma_2[i], self.spectral_pca_gamma_3[i] ) - self.assertAlmostEqual(spectral_gamma_0, self.spectral_gamma_0[i], places=5) - self.assertAlmostEqual(spectral_gamma_1, 
self.spectral_gamma_1[i], places=5) - self.assertAlmostEqual(spectral_gamma_2, self.spectral_gamma_2[i], places=5) - self.assertAlmostEqual(spectral_gamma_3, self.spectral_gamma_3[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_0), self.spectral_gamma_0[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_1), self.spectral_gamma_1[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_2), self.spectral_gamma_2[i], places=5) + self.assertAlmostEqual(float(spectral_gamma_3), self.spectral_gamma_3[i], places=5) for val in [spectral_gamma_0, spectral_gamma_1, spectral_gamma_2, spectral_gamma_3]: self.assertEqual(aac.get_namespace(val), self.xp) @@ -943,8 +946,8 @@ def test_spectral_params_to_lambda_1_lambda_2(self): self.mass_1_source_spectral[i], self.mass_2_source_spectral[i] ) - self.assertAlmostEqual(self.lambda_1_spectral[i], lambda_1, places=0) - self.assertAlmostEqual(self.lambda_2_spectral[i], lambda_2, places=0) + self.assertAlmostEqual(self.lambda_1_spectral[i], float(lambda_1), places=0) + self.assertAlmostEqual(self.lambda_2_spectral[i], float(lambda_2), places=0) self.assertAlmostEqual(self.eos_check_spectral[i], eos_check) def test_polytrope_or_causal_params_to_lambda_1_lambda_2_causal(self): diff --git a/test/gw/detector/geometry_test.py b/test/gw/detector/geometry_test.py index 231b82b17..7340a5f8d 100644 --- a/test/gw/detector/geometry_test.py +++ b/test/gw/detector/geometry_test.py @@ -154,37 +154,37 @@ def test_y_with_latitude_update(self): def test_detector_tensor_with_x_azimuth_update(self): original = self.geometry.detector_tensor self.geometry.xarm_azimuth += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_y_azimuth_update(self): original = self.geometry.detector_tensor 
self.geometry.yarm_azimuth += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_x_tilt_update(self): original = self.geometry.detector_tensor self.geometry.xarm_tilt += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_y_tilt_update(self): original = self.geometry.detector_tensor self.geometry.yarm_tilt += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_longitude_update(self): original = self.geometry.detector_tensor self.geometry.longitude += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_detector_tensor_with_latitude_update(self): original = self.geometry.detector_tensor self.geometry.latitude += 1 - self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0) + self.assertNotEqual(np.max(np.asarray(abs(self.geometry.detector_tensor - original))), 0) self.assertEqual(aac.get_namespace(self.geometry.detector_tensor), self.xp) def test_unit_vector_along_arm_default(self): diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index c2dcf1529..73e6660a2 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -34,12 +34,14 @@ def 
convert_nested_dict(self, data): else: raise ValueError("Input must be an array API object or a dict of such objects.") - def frequency_domain_strain(self, parameters): - wf = self.wfg.frequency_domain_strain(parameters) - return self.convert_nested_dict(wf) + def _strain_from_model(self, model_data_points, model, parameters): + # we can't pass a frequency array through as a torch array + model_data_points = np.asarray(model_data_points) + return super()._strain_from_model(model_data_points, model, parameters) - def time_domain_strain(self, parameters): - wf = self.wfg.time_domain_strain(parameters) + def frequency_domain_strain(self, parameters): + self.wfg.frequency_array = np.asarray(self.wfg.frequency_array) + wf = self.wfg.__class__.frequency_domain_strain(self, parameters) return self.convert_nested_dict(wf) @@ -59,15 +61,15 @@ def setUp(self): phi_jl=0.3, luminosity_distance=4000.0, theta_jn=0.4, - psi=2.659, + psi=self.xp.asarray(2.659), phase=1.3, - geocent_time=1126259642.413, - ra=1.375, - dec=-1.2108, + geocent_time=self.xp.asarray(1126259642.413), + ra=self.xp.asarray(1.375), + dec=self.xp.asarray(-1.2108), ) self.interferometers = bilby.gw.detector.InterferometerList(["H1"]) self.interferometers.set_strain_data_from_power_spectral_densities( - sampling_frequency=2048, duration=4 + sampling_frequency=self.xp.asarray(2048.0), duration=self.xp.asarray(4.0) ) self.interferometers.set_array_backend(self.xp) base_wfg = bilby.gw.waveform_generator.GWSignalWaveformGenerator( @@ -100,7 +102,7 @@ def test_noise_log_likelihood(self): def test_log_likelihood(self): """Test log likelihood matches precomputed value""" logl = self.likelihood.log_likelihood(self.parameters) - self.assertAlmostEqual(logl, -4032.4397343470005, 3) + self.assertAlmostEqual(float(logl), -4032.4397343470005, 3) self.assertEqual(aac.get_namespace(logl), self.xp) def test_log_likelihood_ratio(self): @@ -144,11 +146,11 @@ def setUp(self): phi_jl=0.3, luminosity_distance=4000.0, theta_jn=0.4, 
- psi=2.659, - phase=1.3, - geocent_time=1126259642.413, - ra=1.375, - dec=-1.2108, + psi=self.xp.asarray(2.659), + phase=self.xp.asarray(1.3), + geocent_time=self.xp.asarray(1126259642.413), + ra=self.xp.asarray(1.375), + dec=self.xp.asarray(-1.2108), ) self.interferometers = bilby.gw.detector.InterferometerList(["H1"]) self.interferometers.set_strain_data_from_power_spectral_densities( @@ -184,23 +186,23 @@ def tearDown(self): def test_noise_log_likelihood(self): """Test noise log likelihood matches precomputed value""" nll = self.likelihood.noise_log_likelihood() + self.assertEqual(aac.get_namespace(nll), self.xp) self.assertAlmostEqual( - -4014.1787704539474, nll, 3 + -4014.1787704539474, float(nll), 3 ) - self.assertEqual(aac.get_namespace(nll), self.xp) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" logl = self.likelihood.log_likelihood(self.parameters) - self.assertAlmostEqual(logl, -4032.4397343470005, 3) + self.assertAlmostEqual(float(logl), -4032.4397343470005, 3) self.assertEqual(aac.get_namespace(logl), self.xp) def test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" llr = self.likelihood.log_likelihood_ratio(self.parameters) self.assertAlmostEqual( - self.likelihood.log_likelihood(self.parameters) - self.likelihood.noise_log_likelihood(), - llr, + float(self.likelihood.log_likelihood(self.parameters)) - float(self.likelihood.noise_log_likelihood()), + float(llr), 3, ) self.assertEqual(aac.get_namespace(llr), self.xp) @@ -1365,12 +1367,13 @@ def setUp(self): phi_jl=0.0, luminosity_distance=200.0, theta_jn=0.4, - psi=0.659, + psi=self.xp.asarray(0.659), phase=1.3, - geocent_time=1187008882, - ra=1.3, - dec=-1.2 + geocent_time=self.xp.asarray(1187008882), + ra=self.xp.asarray(1.3), + dec=self.xp.asarray(-1.2) ) # Network SNR is ~50 + # torch needs sky parameters to be tensors self.ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"]) bilby.core.utils.random.seed(70817) @@ 
-1598,6 +1601,7 @@ def test_inout_weights(self, linear_interpolation): reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + wfg = BackendWaveformGenerator(wfg, self.xp) self.ifos.inject_signal( parameters=self.test_parameters, waveform_generator=wfg ) @@ -1609,6 +1613,7 @@ def test_inout_weights(self, linear_interpolation): reference_frequency=self.fmin, waveform_approximant=waveform_approximant ) ) + wfg_mb = BackendWaveformGenerator(wfg_mb, self.xp) likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient( interferometers=self.ifos, waveform_generator=wfg_mb, reference_chirp_mass=self.test_parameters['chirp_mass'], diff --git a/test/gw/prior_test.py b/test/gw/prior_test.py index aec6f01b3..277235a62 100644 --- a/test/gw/prior_test.py +++ b/test/gw/prior_test.py @@ -566,7 +566,8 @@ def test_luminosity_distance_to_comoving_distance(self): @pytest.mark.usefixtures("xp_class") class TestAlignedSpin(unittest.TestCase): def setUp(self): - pass + if aac.is_torch_namespace(self.xp): + self.skipTest("Torch doesn't support interpolated priors.") def test_default_prior_matches_analytic(self): prior = bilby.gw.prior.AlignedSpin() @@ -591,7 +592,8 @@ def test_non_analytic_form_has_correct_statistics(self): class TestConditionalChiUniformSpinMagnitude(unittest.TestCase): def setUp(self): - pass + if aac.is_torch_namespace(self.xp): + self.skipTest("Torch doesn't support interpolated priors.") def test_marginalized_prior_is_uniform(self): priors = bilby.gw.prior.BBHPriorDict(aligned_spin=True) diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index 80b8fe60d..b4bb1af8e 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -33,15 +33,19 @@ def test_asd_from_freq_series(self): freq_data = self.xp.asarray([1, 2, 3]) df = 0.1 asd = gwutils.asd_from_freq_series(freq_data, df) - self.assertTrue(np.all(asd == freq_data * 2 * df ** 0.5)) self.assertEqual(aac.get_namespace(asd), self.xp) + asd = np.asarray(asd) + freq_data = 
np.asarray(freq_data) + self.assertTrue(np.all(asd == freq_data * 2 * df ** 0.5)) def test_psd_from_freq_series(self): freq_data = self.xp.asarray([1, 2, 3]) df = 0.1 psd = gwutils.psd_from_freq_series(freq_data, df) - self.assertTrue(np.all(psd == (freq_data * 2 * df ** 0.5) ** 2)) self.assertEqual(aac.get_namespace(psd), self.xp) + psd = np.asarray(psd) + freq_data = np.asarray(freq_data) + self.assertTrue(np.all(psd == (freq_data * 2 * df ** 0.5) ** 2)) def test_inner_product(self): aa = self.xp.asarray([1, 2, 3]) @@ -99,8 +103,8 @@ def test_overlap(self): norm_a=gwutils.optimal_snr_squared(signal, psd, duration), norm_b=gwutils.optimal_snr_squared(frequency_domain_strain, psd, duration), ) - self.assertAlmostEqual(overlap, 2.76914407e-05) self.assertEqual(aac.get_namespace(overlap), self.xp) + self.assertAlmostEqual(float(overlap), 2.76914407e-05) @pytest.mark.skip(reason="GWOSC unstable: avoiding this test") def test_get_event_time(self): diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py index 8a98369b5..70e48aa83 100644 --- a/test/gw/waveform_generator_test.py +++ b/test/gw/waveform_generator_test.py @@ -139,10 +139,10 @@ def test_source_model(self): ) def test_frequency_array_type(self): - self.assertIsInstance(self.waveform_generator.frequency_array, self.xp.ndarray) + self.assertEqual(aac.array_namespace(self.waveform_generator.frequency_array), self.xp) def test_time_array_type(self): - self.assertIsInstance(self.waveform_generator.time_array, self.xp.ndarray) + self.assertEqual(aac.array_namespace(self.waveform_generator.time_array), self.xp) def test_source_model_parameters(self): self.waveform_generator.parameters = self.simulation_parameters.copy() From 1759f6ace956d534e5968975b80ea8137965fcee Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 10:56:47 -0500 Subject: [PATCH 095/110] FMT: pre-commit fix --- bilby/gw/likelihood/multiband.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff 
--git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index 69958a816..dd0fe5ad0 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -534,8 +534,10 @@ def _setup_linear_coefficients(self): for ifo in self.interferometers: logger.info("Pre-computing linear coefficients for {}".format(ifo.name)) fddata = np.zeros(N // 2 + 1, dtype=complex) - fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask[:len(fddata)]] += \ - np.asarray(ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask]) + fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask[:len(fddata)]] += np.asarray( + ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] + ) + for b in range(self.number_of_bands): Ks, Ke = self.Ks_Ke[b] windows = self._get_window_sequence(1. / self.durations[b], Ks, Ke - Ks + 1, b) From b74b838b4045fd8489feb739fd0b7ae880932a4e Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 11:37:39 -0500 Subject: [PATCH 096/110] TEST: fix torch roq tests --- bilby/gw/likelihood/multiband.py | 2 +- bilby/gw/likelihood/roq.py | 6 +++--- test/gw/likelihood_test.py | 16 ++++++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index dd0fe5ad0..6f80b1537 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -537,7 +537,7 @@ def _setup_linear_coefficients(self): fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask[:len(fddata)]] += np.asarray( ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] ) - + for b in range(self.number_of_bands): Ks, Ke = self.Ks_Ke[b] windows = self._get_window_sequence(1. 
/ self.durations[b], Ks, Ke - Ks + 1, b) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index b5d84bafa..55a980f06 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -661,17 +661,17 @@ def perform_roq_params_check(self, ifo=None): except ValueError: roq_minimum_component_mass = None - if ifo.maximum_frequency > roq_maximum_frequency: + if float(ifo.maximum_frequency) > roq_maximum_frequency: raise BilbyROQParamsRangeError( "Requested maximum frequency {} larger than ROQ basis fhigh {}" .format(ifo.maximum_frequency, roq_maximum_frequency) ) - if ifo.minimum_frequency < roq_minimum_frequency: + if float(ifo.minimum_frequency) < roq_minimum_frequency: raise BilbyROQParamsRangeError( "Requested minimum frequency {} lower than ROQ basis flow {}" .format(ifo.minimum_frequency, roq_minimum_frequency) ) - if ifo.strain_data.duration != roq_segment_length: + if float(ifo.strain_data.duration) != roq_segment_length: raise BilbyROQParamsRangeError( "Requested duration differs from ROQ basis seglen") diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 73e6660a2..f5b991d9f 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -354,11 +354,11 @@ def setUp(self): phi_jl=0.3, luminosity_distance=1000.0, theta_jn=0.4, - psi=0.659, + psi=self.xp.asarray(0.659), phase=1.3, - geocent_time=1.2, - ra=1.3, - dec=-1.2, + geocent_time=self.xp.asarray(1.2), + ra=self.xp.asarray(1.3), + dec=self.xp.asarray(-1.2), ) ifos = bilby.gw.detector.InterferometerList(["H1"]) @@ -732,11 +732,11 @@ def setUp(self): chi_2=0.0, luminosity_distance=100.0, theta_jn=0.4, - psi=0.659, + psi=self.xp.asarray(0.659), phase=1.3, - geocent_time=1.2, - ra=1.3, - dec=-1.2 + geocent_time=self.xp.asarray(1.2), + ra=self.xp.asarray(1.3), + dec=self.xp.asarray(-1.2) ) self.priors = bilby.gw.prior.BBHPriorDict() self.priors.pop("mass_1") From f734a33b51d532535a76baefa09ff0434684d3ff Mon Sep 17 00:00:00 2001 From: Colm Talbot 
Date: Tue, 3 Feb 2026 11:47:07 -0500 Subject: [PATCH 097/110] CI: prioritize torch tests --- .github/workflows/unit-tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 1a6275fd3..153a6b574 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -57,6 +57,9 @@ jobs: # - name: Run precommits # run: | # pre-commit run --all-files --verbose --show-diff-on-failure + - name: Run torch-backend unit tests + run: | + pytest --array-backend torch --durations 10 - name: Run unit tests run: | python -m pytest --cov=bilby --cov-branch --durations 10 -ra --color yes --cov-report=xml --junitxml=pytest.xml @@ -64,9 +67,6 @@ jobs: run: | python -m pip install .[jax] pytest --array-backend jax --durations 10 - - name: Run torch-backend unit tests - run: | - pytest --array-backend torch --durations 10 - name: Run sampler tests run: | pytest test/integration/sampler_run_test.py --durations 10 -v From 602a48d4e444798598ed5e9729a3b6e56de87ce9 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 11:56:22 -0500 Subject: [PATCH 098/110] TEST: another attempt to fix torch tests --- .github/workflows/unit-tests.yml | 4 ++-- bilby/gw/likelihood/roq.py | 2 +- test/conftest.py | 2 -- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 153a6b574..a0b81d424 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -59,14 +59,14 @@ jobs: # pre-commit run --all-files --verbose --show-diff-on-failure - name: Run torch-backend unit tests run: | - pytest --array-backend torch --durations 10 + SCIPY_ARRAY_API=1 pytest --array-backend torch --durations 10 - name: Run unit tests run: | python -m pytest --cov=bilby --cov-branch --durations 10 -ra --color yes --cov-report=xml --junitxml=pytest.xml - name: Run jax-backend unit tests run: | python -m 
pip install .[jax] - pytest --array-backend jax --durations 10 + SCIPY_ARRAY_API=1 pytest --array-backend jax --durations 10 - name: Run sampler tests run: | pytest test/integration/sampler_run_test.py --durations 10 -v diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 55a980f06..9c97fdaba 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -736,7 +736,7 @@ def _set_weights(self, linear_matrix, quadratic_matrix): - self.interferometers.start_time ) / time_space)) ) - self.weights['time_samples'] = np.arange(start_idx, end_idx + 1) * time_space + self.weights['time_samples'] = np.arange(start_idx, end_idx + 1) * float(time_space) logger.info("Using {} ROQ time samples".format(len(self.weights['time_samples']))) # select bases to be used, set prior ranges and frequency nodes if exist diff --git a/test/conftest.py b/test/conftest.py index a9be96548..83e7a89a6 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -42,10 +42,8 @@ def _xp(request): case None | "numpy": import numpy as xp case "jax" | "jax.numpy": - import os import jax - os.environ["SCIPY_ARRAY_API"] = "1" jax.config.update("jax_enable_x64", True) xp = jax.numpy case "torch": From ebb5ef2050173d88d4f7db1526dcae8a9a8706a7 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 12:13:20 -0500 Subject: [PATCH 099/110] Another attempt at fixing torch ROQ tests --- bilby/gw/likelihood/roq.py | 4 ++-- test/gw/likelihood_test.py | 12 ++++++------ test/gw/utils_test.py | 6 ++++-- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 9c97fdaba..ff691e1b8 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -873,7 +873,7 @@ def _set_weights_linear(self, linear_matrix, basis_idxs, roq_idxs, ifo_idxs): basis_element = linear_matrix_single[i][roq_idxs[ifo.name]] ifft_input[nonzero_idxs[ifo.name]] = data_over_psd[ifo.name] * np.conj(basis_element) linear_weights[:, 
i] = ifft(ifft_input)[start_idx:end_idx + 1] - linear_weights *= 4. * number_of_time_samples / self.interferometers.duration + linear_weights *= 4. * number_of_time_samples / float(self.interferometers.duration) self.weights[ifo.name + '_linear'].append(linear_weights) if pyfftw is not None: pyfftw.forget_wisdom() @@ -913,7 +913,7 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): )[start_frequency_bin:end_frequency_bin + 1] start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1 tc_shifted_data[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. / Tb * Db[:, None] * np.exp( - 2. * np.pi * 1j * fs[:, None] * (self.weights['time_samples'][None, :] - ifo.duration + Tb)) + 2. * np.pi * 1j * fs[:, None] * (self.weights['time_samples'][None, :] - float(ifo.duration) + Tb)) start_idx_of_band = start_idx_of_next_band # compute inner products for basis_idx in basis_idxs: diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index f5b991d9f..9fc7eb9ab 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -770,11 +770,11 @@ def test_fails_with_frequency_duration_mismatch( ) interferometers.set_array_backend(self.xp) for ifo in interferometers: - ifo.minimum_frequency = minimum_frequency - ifo.maximum_frequency = maximum_frequency + ifo.minimum_frequency = self.xp.asarray(minimum_frequency) + ifo.maximum_frequency = self.xp.asarray(maximum_frequency) search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( - duration=duration, - sampling_frequency=2 * maximum_frequency, + duration=self.xp.asarray(duration), + sampling_frequency=self.xp.asarray(2 * maximum_frequency), frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, waveform_arguments=dict( reference_frequency=self.reference_frequency, @@ -946,9 +946,9 @@ def assertLess_likelihood_errors( if minimum_frequency is None: ifo.minimum_frequency = self.minimum_frequency else: - ifo.minimum_frequency = 
minimum_frequency + ifo.minimum_frequency = self.xp.asarray(minimum_frequency) if maximum_frequency is not None: - ifo.maximum_frequency = maximum_frequency + ifo.maximum_frequency = self.xp.asarray(maximum_frequency) interferometers.set_strain_data_from_zero_noise( sampling_frequency=self.sampling_frequency, duration=self.duration, diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index b4bb1af8e..2fc700993 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -335,10 +335,12 @@ def test_conversion_gives_correct_prior(self) -> None: ras, decs = bilby.gw.utils.zenith_azimuth_to_ra_dec( zeniths, azimuths, times, self.ifos ) - self.assertGreaterEqual(ks_2samp(self.samples["ra"], ras).pvalue, 0.01) - self.assertGreaterEqual(ks_2samp(self.samples["dec"], decs).pvalue, 0.01) self.assertEqual(aac.get_namespace(ras), self.xp) self.assertEqual(aac.get_namespace(decs), self.xp) + ras = np.asarray(ras) + decs = np.asarray(decs) + self.assertGreaterEqual(ks_2samp(self.samples["ra"], ras).pvalue, 0.01) + self.assertGreaterEqual(ks_2samp(self.samples["dec"], decs).pvalue, 0.01) @pytest.mark.array_backend From e21790da60be265bce2efd8f82276e405803440f Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 12:40:05 -0500 Subject: [PATCH 100/110] Fix arrays of data setting --- bilby/gw/detector/interferometer.py | 28 +++++++++---- bilby/gw/detector/strain_data.py | 63 +++++++++++++++++++++-------- bilby/gw/likelihood/roq.py | 2 +- test/gw/likelihood_test.py | 2 +- 4 files changed, 69 insertions(+), 26 deletions(-) diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index bab36f05a..4fcfe98ec 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -115,16 +115,19 @@ def __repr__(self): float(self.geometry.yarm_azimuth), float(self.geometry.xarm_tilt), float(self.geometry.yarm_tilt)) - def set_strain_data_from_gwpy_timeseries(self, time_series): + def 
set_strain_data_from_gwpy_timeseries(self, time_series, *, xp=None): """ Set the `Interferometer.strain_data` from a gwpy TimeSeries Parameters ========== time_series: gwpy.timeseries.timeseries.TimeSeries The data to set. + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. """ - self.strain_data.set_from_gwpy_timeseries(time_series=time_series) + self.strain_data.set_from_gwpy_timeseries(time_series=time_series, xp=xp) def set_strain_data_from_frequency_domain_strain( self, frequency_domain_strain, sampling_frequency=None, @@ -175,7 +178,7 @@ def set_strain_data_from_power_spectral_density( def set_strain_data_from_frame_file( self, frame_file, sampling_frequency, duration, start_time=0, - channel=None, buffer_time=1): + channel=None, buffer_time=1, *, xp=None): """ Set the `Interferometer.strain_data` from a frame file Parameters @@ -193,15 +196,18 @@ def set_strain_data_from_frame_file( buffer_time: float Read in data with `start_time-buffer_time` and `start_time+duration+buffer_time` + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. 
""" self.strain_data.set_from_frame_file( frame_file=frame_file, sampling_frequency=sampling_frequency, duration=duration, start_time=start_time, - channel=channel, buffer_time=buffer_time) + channel=channel, buffer_time=buffer_time, xp=xp) def set_strain_data_from_channel_name( - self, channel, sampling_frequency, duration, start_time=0): + self, channel, sampling_frequency, duration, start_time=0, *, xp=None): """ Set the `Interferometer.strain_data` by fetching from given channel using strain_data.set_from_channel_name() @@ -216,22 +222,28 @@ def set_strain_data_from_channel_name( The data duration (in s) start_time: float The GPS start-time of the data + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. """ self.strain_data.set_from_channel_name( channel=channel, sampling_frequency=sampling_frequency, - duration=duration, start_time=start_time) + duration=duration, start_time=start_time, xp=xp) - def set_strain_data_from_csv(self, filename): + def set_strain_data_from_csv(self, filename, *, xp=None): """ Set the `Interferometer.strain_data` from a csv file Parameters ========== filename: str The path to the file to read in + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. 
""" - self.strain_data.set_from_csv(filename) + self.strain_data.set_from_csv(filename, xp=xp) def set_strain_data_from_zero_noise( self, sampling_frequency, duration, start_time=0): diff --git a/bilby/gw/detector/strain_data.py b/bilby/gw/detector/strain_data.py index bca7acced..017d2ea50 100644 --- a/bilby/gw/detector/strain_data.py +++ b/bilby/gw/detector/strain_data.py @@ -1,5 +1,7 @@ +import array_api_compat as aac import numpy as np +from ...compat.utils import array_module from ...core import utils from ...core.series import CoupledTimeAndFrequencySeries from ...core.utils import logger, PropertyAccessor @@ -498,7 +500,7 @@ def set_from_time_domain_strain( else: raise ValueError("Data times do not match time array") - def set_from_gwpy_timeseries(self, time_series): + def set_from_gwpy_timeseries(self, time_series, *, xp=np): """ Set the strain data from a gwpy TimeSeries This sets the time_domain_strain attribute, the frequency_domain_strain @@ -509,17 +511,23 @@ def set_from_gwpy_timeseries(self, time_series): ========== time_series: gwpy.timeseries.timeseries.TimeSeries The data to use + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. 
""" from gwpy.timeseries import TimeSeries logger.debug('Setting data using provided gwpy TimeSeries object') if not isinstance(time_series, TimeSeries): raise ValueError("Input time_series is not a gwpy TimeSeries") + duration = xp.asarray(time_series.duration.value) + sampling_frequency = xp.asarray(time_series.sample_rate.value) + start_time = xp.asarray(time_series.epoch.value) self._times_and_frequencies = \ - CoupledTimeAndFrequencySeries(duration=time_series.duration.value, - sampling_frequency=time_series.sample_rate.value, - start_time=time_series.epoch.value) - self._time_domain_strain = time_series.value + CoupledTimeAndFrequencySeries(duration=duration, + sampling_frequency=sampling_frequency, + start_time=start_time) + self._time_domain_strain = xp.asarray(time_series.value) self._frequency_domain_strain = None self._channel = time_series.channel @@ -529,7 +537,7 @@ def channel(self): def set_from_open_data( self, name, start_time, duration=4, outdir='outdir', cache=True, - **kwargs): + *, xp=None, **kwargs): """ Set the strain data from open LOSC data This sets the time_domain_strain attribute, the frequency_domain_strain @@ -548,30 +556,38 @@ def set_from_open_data( Directory where the psd files are saved cache: bool, optional Whether or not to store/use the acquired data. + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. **kwargs: All keyword arguments are passed to `gwpy.timeseries.TimeSeries.fetch_open_data()`. 
""" - timeseries = gwutils.get_open_strain_data( - name, start_time, start_time + duration, outdir=outdir, cache=cache, + name, float(start_time), float(start_time + duration), outdir=outdir, cache=cache, **kwargs) - self.set_from_gwpy_timeseries(timeseries) + if xp is None: + xp = array_module((duration, start_time)) + + self.set_from_gwpy_timeseries(timeseries, xp=xp) - def set_from_csv(self, filename): + def set_from_csv(self, filename, xp=None): """ Set the strain data from a csv file Parameters ========== filename: str The path to the file to read in + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified :code:`numpy` will be used. """ from gwpy.timeseries import TimeSeries timeseries = TimeSeries.read(filename, format='csv') - self.set_from_gwpy_timeseries(timeseries) + self.set_from_gwpy_timeseries(timeseries, xp=xp) def set_from_frequency_domain_strain( self, frequency_domain_strain, sampling_frequency=None, @@ -661,12 +677,13 @@ def set_from_zero_noise(self, sampling_frequency, duration, start_time=0): sampling_frequency=sampling_frequency, start_time=start_time) logger.debug('Setting zero noise data') - self._frequency_domain_strain = np.zeros_like(self.frequency_array, + xp = aac.get_namespace(self.frequency_array) + self._frequency_domain_strain = xp.zeros_like(self.frequency_array, dtype=complex) def set_from_frame_file( self, frame_file, sampling_frequency, duration, start_time=0, - channel=None, buffer_time=1): + channel=None, buffer_time=1, *, xp=None): """ Set the `frequency_domain_strain` from a frame fiile Parameters @@ -684,6 +701,10 @@ def set_from_frame_file( buffer_time: float Read in data with `start_time-buffer_time` and `start_time+duration+buffer_time` + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified, it will be inferred from the provided duration/ + sampling frequency. 
""" @@ -697,9 +718,12 @@ def set_from_frame_file( buffer_time=buffer_time, channel=channel, resample=sampling_frequency) - self.set_from_gwpy_timeseries(strain) + if xp is None: + xp = aac.get_namespace(self.frequency_array) - def set_from_channel_name(self, channel, duration, start_time, sampling_frequency): + self.set_from_gwpy_timeseries(strain, xp=xp) + + def set_from_channel_name(self, channel, duration, start_time, sampling_frequency, *, xp=None): """ Set the `frequency_domain_strain` by fetching from given channel using gwpy.TimesSeries.get(), which dynamically accesses either frames on disk, or a remote NDS2 server to find and return data. This function @@ -715,6 +739,10 @@ def set_from_channel_name(self, channel, duration, start_time, sampling_frequenc The GPS start-time of the data sampling_frequency: float The sampling frequency (in Hz) + xp: array module, optional + The array module to use, e.g., :code:`numpy` or :code:`jax.numpy`. + If not specified, it will be inferred from the provided duration/ + sampling frequency. 
""" from gwpy.timeseries import TimeSeries @@ -730,7 +758,10 @@ def set_from_channel_name(self, channel, duration, start_time, sampling_frequenc strain = TimeSeries.get(channel, start_time, start_time + duration) strain = strain.resample(sampling_frequency) - self.set_from_gwpy_timeseries(strain) + if xp is None: + xp = aac.get_namespace(self.frequency_array) + + self.set_from_gwpy_timeseries(strain, xp=xp) class Notch(object): diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index ff691e1b8..3e8ae15d2 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -546,7 +546,7 @@ def _closest_time_indices(time, samples): closest = xp.floor((time - samples[0]) / (samples[1] - samples[0])) indices = [closest + ii for ii in [-2, -1, 0, 1, 2]] in_bounds = (indices[0] >= 0) & (indices[-1] < samples.size) - return xp.asarray(indices).astype(int), in_bounds + return xp.astype(indices, int), in_bounds @staticmethod def _interp_five_samples(time_samples, values, time): diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 9fc7eb9ab..e912bc354 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -95,7 +95,7 @@ def test_noise_log_likelihood(self): """Test noise log likelihood matches precomputed value""" nll = self.likelihood.noise_log_likelihood() self.assertAlmostEqual( - -4014.1787704539474, nll, 3 + -4014.1787704539474, float(nll), 3 ) self.assertEqual(aac.get_namespace(nll), self.xp) From 721a03326c74561fd49d48185e20d82bee7ed3bc Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 12:58:03 -0500 Subject: [PATCH 101/110] BUG: fix some more roq array issues --- bilby/gw/detector/interferometer.py | 2 +- bilby/gw/likelihood/roq.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 4fcfe98ec..bf1543f0b 100644 --- a/bilby/gw/detector/interferometer.py +++ 
b/bilby/gw/detector/interferometer.py @@ -363,7 +363,7 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= signal_ifo = signal_ifo * xp.exp(-1j * 2 * np.pi * dt * frequencies) signal_ifo *= self.calibration_model.get_calibration_factor( - frequencies, prefix='recalib_{}_'.format(self.name), **parameters + frequencies, prefix=f'recalib_{self.name}_', xp=xp, **parameters ) return signal_ifo diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 3e8ae15d2..942930459 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -472,13 +472,13 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr xp = array_module(h_linear) calib_factor = interferometer.calibration_model.get_calibration_factor( - frequency_nodes, prefix='recalib_{}_'.format(interferometer.name), **parameters) + xp.asarray(frequency_nodes), prefix=f'recalib_{interferometer.name}_', xp=xp, **parameters) h_linear *= calib_factor[linear_indices] h_quadratic *= calib_factor[quadratic_indices] optimal_snr_squared = xp.vdot( xp.abs(h_quadratic)**2, - self.weights[interferometer.name + '_quadratic'][self.basis_number_quadratic] + xp.asarray(self.weights[interferometer.name + '_quadratic'][self.basis_number_quadratic]) ) dt = interferometer.time_delay_from_geocenter( @@ -487,7 +487,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr ifo_time = dt_geocent + dt indices, in_bounds = self._closest_time_indices( - ifo_time, self.weights['time_samples']) + ifo_time, xp.asarray(self.weights['time_samples'])) indices = xp.clip(indices, 0, len(self.weights['time_samples']) - 1) d_inner_h_tc_array = xp.einsum( 'i,ji->j', From 78e94da369aeecd5db2d62a571ef7259bfa51841 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 14:37:39 -0500 Subject: [PATCH 102/110] Make ROQ calculations use correct array backend --- bilby/gw/likelihood/roq.py | 165 ++++++++++++++++++++--------------- 
test/gw/likelihood_test.py | 170 +++++++++++++++---------------------- 2 files changed, 165 insertions(+), 170 deletions(-) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 942930459..c52a32a59 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -1,5 +1,5 @@ - import array_api_compat as aac +import array_api_extra as xpx import numpy as np from .base import GravitationalWaveTransient @@ -273,15 +273,16 @@ def _set_unique_frequency_nodes_and_inverse(self): """Set unique frequency nodes and indices to recover linear and quadratic frequency nodes for each combination of linear and quadratic bases """ + xp = aac.array_namespace(self.interferometers.frequency_array) self._unique_frequency_nodes_and_inverse = [] for idx_linear in range(self.number_of_bases_linear): tmp = [] - frequency_nodes_linear = self.weights['frequency_nodes_linear'][idx_linear] + frequency_nodes_linear = xp.asarray(self.weights['frequency_nodes_linear'][idx_linear]) size_linear = len(frequency_nodes_linear) for idx_quadratic in range(self.number_of_bases_quadratic): - frequency_nodes_quadratic = self.weights['frequency_nodes_quadratic'][idx_quadratic] - frequency_nodes_unique, original_indices = np.unique( - np.hstack((frequency_nodes_linear, frequency_nodes_quadratic)), + frequency_nodes_quadratic = xp.asarray(self.weights['frequency_nodes_quadratic'][idx_quadratic]) + frequency_nodes_unique, original_indices = xp.unique( + xp.hstack((frequency_nodes_linear, frequency_nodes_quadratic)), return_inverse=True ) linear_indices = original_indices[:size_linear] @@ -488,7 +489,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr indices, in_bounds = self._closest_time_indices( ifo_time, xp.asarray(self.weights['time_samples'])) - indices = xp.clip(indices, 0, len(self.weights['time_samples']) - 1) + indices = xp.clip(xp.asarray(indices), 0, len(self.weights['time_samples']) - 1) d_inner_h_tc_array = xp.einsum( 'i,ji->j', 
xp.conj(h_linear), @@ -543,10 +544,10 @@ def _closest_time_indices(time, samples): Whether the indices are for valid times """ xp = array_module(time) - closest = xp.floor((time - samples[0]) / (samples[1] - samples[0])) + closest = xp.astype(xp.floor((time - samples[0]) / (samples[1] - samples[0])), int) indices = [closest + ii for ii in [-2, -1, 0, 1, 2]] - in_bounds = (indices[0] >= 0) & (indices[-1] < samples.size) - return xp.astype(indices, int), in_bounds + in_bounds = (indices[0] >= 0) & (indices[-1] < len(samples)) + return indices, in_bounds @staticmethod def _interp_five_samples(time_samples, values, time): @@ -568,16 +569,15 @@ def _interp_five_samples(time_samples, values, time): value: float The value of the function at the input time """ - xp = aac.get_namespace(time_samples) r1 = (-values[0] + 8. * values[1] - 14. * values[2] + 8. * values[3] - values[4]) / 4. r2 = values[2] - 2. * values[3] + values[4] - a = (time_samples[3] - time) / xp.maximum(time_samples[1] - time_samples[0], 1e-12) + a = (time_samples[3] - time) / max(time_samples[1] - time_samples[0], 1e-12) b = 1. - a c = (a**3. - a) / 6. d = (b**3. - b) / 6. return a * values[2] + b * values[3] + c * r1 + d * r2 - def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): + def _calculate_d_inner_h_array(self, times, h_linear, ifo_name, *, xp=None): """ Calculate d_inner_h at regularly-spaced time samples. 
Each value is interpolated from the nearest 5 samples with the algorithm explained in @@ -595,21 +595,23 @@ def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): ======= d_inner_h_array: array-like """ + if xp is None: + xp = aac.array_namespace(h_linear) roq_time_space = self.weights['time_samples'][1] - self.weights['time_samples'][0] times_per_roq_time_space = (times - self.weights['time_samples'][0]) / roq_time_space - closest_idxs = np.floor(times_per_roq_time_space).astype(int) + closest_idxs = xp.astype(xp.floor(times_per_roq_time_space), int) # Get the nearest 5 samples of d_inner_h. Calculate only the required d_inner_h values if the time # spacing is larger than 5 times the ROQ time spacing. weights_linear = self.weights[ifo_name + '_linear'][self.basis_number_linear] - h_linear_conj = np.conj(h_linear) + h_linear_conj = h_linear.conj() if (times[1] - times[0]) / roq_time_space > 5: - d_inner_h_m2 = np.dot(weights_linear[closest_idxs - 2], h_linear_conj) - d_inner_h_m1 = np.dot(weights_linear[closest_idxs - 1], h_linear_conj) - d_inner_h_0 = np.dot(weights_linear[closest_idxs], h_linear_conj) - d_inner_h_p1 = np.dot(weights_linear[closest_idxs + 1], h_linear_conj) - d_inner_h_p2 = np.dot(weights_linear[closest_idxs + 2], h_linear_conj) + d_inner_h_m2 = weights_linear[closest_idxs - 2] @ h_linear_conj + d_inner_h_m1 = weights_linear[closest_idxs - 1] @ h_linear_conj + d_inner_h_0 = weights_linear[closest_idxs] @ h_linear_conj + d_inner_h_p1 = weights_linear[closest_idxs + 1] @ h_linear_conj + d_inner_h_p2 = weights_linear[closest_idxs + 2] @ h_linear_conj else: - d_inner_h_at_roq_time_samples = np.dot(weights_linear, h_linear_conj) + d_inner_h_at_roq_time_samples = weights_linear @ h_linear_conj d_inner_h_m2 = d_inner_h_at_roq_time_samples[closest_idxs - 2] d_inner_h_m1 = d_inner_h_at_roq_time_samples[closest_idxs - 1] d_inner_h_0 = d_inner_h_at_roq_time_samples[closest_idxs] @@ -717,6 +719,7 @@ def _set_weights(self, linear_matrix, 
quadratic_matrix): linear and quadratic basis """ + xp = aac.array_namespace(self.interferometers.frequency_array) time_space = self._get_time_resolution() number_of_time_samples = int(self.interferometers.duration / time_space) earth_light_crossing_time = 2 * radius_of_earth / speed_of_light + 5 * time_space @@ -736,7 +739,7 @@ def _set_weights(self, linear_matrix, quadratic_matrix): - self.interferometers.start_time ) / time_space)) ) - self.weights['time_samples'] = np.arange(start_idx, end_idx + 1) * float(time_space) + self.weights['time_samples'] = xp.arange(start_idx, end_idx + 1) * float(time_space) logger.info("Using {} ROQ time samples".format(len(self.weights['time_samples']))) # select bases to be used, set prior ranges and frequency nodes if exist @@ -789,10 +792,10 @@ def _set_weights(self, linear_matrix, quadratic_matrix): roq_mask = roq_frequencies >= roq_scaled_minimum_frequency roq_frequencies = roq_frequencies[roq_mask] overlap_frequencies, ifo_idxs_this_ifo, roq_idxs_this_ifo = np.intersect1d( - ifo.frequency_array[ifo.frequency_mask], roq_frequencies, + np.asarray(ifo.frequency_array[ifo.frequency_mask]), roq_frequencies, return_indices=True) else: - overlap_frequencies = ifo.frequency_array[ifo.frequency_mask] + overlap_frequencies = np.asarray(ifo.frequency_array[ifo.frequency_mask]) roq_idxs_this_ifo = np.arange( linear_matrix['basis_linear'][str(idxs_in_prior_range['linear'][0])]['basis'].shape[1], dtype=int) @@ -848,31 +851,43 @@ def _set_weights_linear(self, linear_matrix, basis_idxs, roq_idxs, ifo_idxs): data_over_psd = {} for ifo in self.interferometers: nonzero_idxs[ifo.name] = ifo_idxs[ifo.name] + int( - ifo.frequency_array[ifo.frequency_mask][0] * self.interferometers.duration) - data_over_psd[ifo.name] = ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs[ifo.name]] / \ - ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]] - try: - import pyfftw - ifft_input = pyfftw.empty_aligned(number_of_time_samples, 
dtype=complex) - ifft_output = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) - ifft = pyfftw.FFTW(ifft_input, ifft_output, direction='FFTW_BACKWARD') - except ImportError: + ifo.minimum_frequency * self.interferometers.duration) + data_over_psd[ifo.name] = ( + ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs[ifo.name]] + / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]] + ) + xp = array_module(data_over_psd) + if aac.is_numpy_namespace(xp): + try: + import pyfftw + ifft_input = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) + ifft_output = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) + ifft = pyfftw.FFTW(ifft_input, ifft_output, direction='FFTW_BACKWARD') + except ImportError: + pyfftw = None + logger.warning("You do not have pyfftw installed, falling back to numpy.fft.") + ifft_input = np.zeros(number_of_time_samples, dtype=complex) + ifft = np.fft.ifft + else: pyfftw = None - logger.warning("You do not have pyfftw installed, falling back to numpy.fft.") - ifft_input = np.zeros(number_of_time_samples, dtype=complex) - ifft = np.fft.ifft + ifft_input = xp.zeros(number_of_time_samples, dtype=complex) + ifft = xp.fft.ifft for basis_idx in basis_idxs: logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.") - linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis'] + linear_matrix_single = xp.asarray(linear_matrix['basis_linear'][str(basis_idx)]['basis']) basis_size = linear_matrix_single.shape[0] for ifo in self.interferometers: - ifft_input[:] *= 0. + if pyfftw: + ifft_input[:] *= 0. 
+ else: + ifft_input *= 0 linear_weights = \ - np.zeros((len(self.weights['time_samples']), basis_size), dtype=complex) + xp.zeros((basis_size, len(self.weights['time_samples'])), dtype=complex) for i in range(basis_size): - basis_element = linear_matrix_single[i][roq_idxs[ifo.name]] - ifft_input[nonzero_idxs[ifo.name]] = data_over_psd[ifo.name] * np.conj(basis_element) - linear_weights[:, i] = ifft(ifft_input)[start_idx:end_idx + 1] + basis_element = xp.asarray(linear_matrix_single[i][roq_idxs[ifo.name]]).conj() + ifft_input = xpx.at(ifft_input, nonzero_idxs[ifo.name]).set(data_over_psd[ifo.name] * basis_element) + linear_weights = xpx.at(linear_weights, i).set(ifft(ifft_input)[start_idx:end_idx + 1]) + linear_weights = linear_weights.T linear_weights *= 4. * number_of_time_samples / float(self.interferometers.duration) self.weights[ifo.name + '_linear'].append(linear_weights) if pyfftw is not None: @@ -892,6 +907,7 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): """ for ifo in self.interferometers: self.weights[ifo.name + '_linear'] = [] + xp = aac.array_namespace(self.interferometers.frequency_array) Tbs = linear_matrix['durations_s_linear'][()] / self.roq_scale_factor start_end_frequency_bins = linear_matrix['start_end_frequency_bins_linear'][()] basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1) @@ -899,29 +915,36 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): # prepare time-shifted data, which is multiplied by basis tc_shifted_data = dict() for ifo in self.interferometers: - over_whitened_frequency_data = np.zeros(int(fhigh_basis * ifo.duration) + 1, dtype=complex) - over_whitened_frequency_data[np.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask]] = \ - ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] - over_whitened_time_data = np.fft.irfft(over_whitened_frequency_data) - tc_shifted_data[ifo.name] = 
np.zeros((basis_dimension, len(self.weights['time_samples'])), dtype=complex) + over_whitened_frequency_data = xp.zeros(int(fhigh_basis * ifo.duration) + 1, dtype=complex) + over_whitened_frequency_data = xpx.at( + over_whitened_frequency_data, xp.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask] + ).set(ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask]) + over_whitened_time_data = xp.fft.irfft(over_whitened_frequency_data) + tc_shifted_data[ifo.name] = xp.zeros((basis_dimension, len(self.weights['time_samples'])), dtype=complex) start_idx_of_band = 0 for b, Tb in enumerate(Tbs): start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b] - fs = np.arange(start_frequency_bin, end_frequency_bin + 1) / Tb - Db = np.fft.rfft( + fs = xp.arange(start_frequency_bin, end_frequency_bin + 1) / Tb + Db = xp.fft.rfft( over_whitened_time_data[-int(2. * fhigh_basis * Tb):] )[start_frequency_bin:end_frequency_bin + 1] start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1 - tc_shifted_data[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. / Tb * Db[:, None] * np.exp( - 2. * np.pi * 1j * fs[:, None] * (self.weights['time_samples'][None, :] - float(ifo.duration) + Tb)) + this_data = xp.zeros(len(self.weights['time_samples']), dtype=complex) + sl = slice(start_idx_of_band, start_idx_of_next_band) + this_data = ( + 4. / Tb * Db[:, None] * xp.exp( + 2. 
* np.pi * 1j * fs[:, None] * (xp.asarray(self.weights['time_samples'][None, :]) - ifo.duration + Tb) + ) + ) + tc_shifted_data[ifo.name] = xpx.at(tc_shifted_data[ifo.name], sl).set(this_data) + start_idx_of_band = start_idx_of_next_band # compute inner products for basis_idx in basis_idxs: logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.") - linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis'][()] + linear_matrix_single = xp.asarray(linear_matrix['basis_linear'][str(basis_idx)]['basis'][()]) for ifo in self.interferometers: - self.weights[ifo.name + '_linear'].append( - np.dot(np.conj(linear_matrix_single), tc_shifted_data[ifo.name]).T) + self.weights[ifo.name + '_linear'].append((linear_matrix_single.conj() @ tc_shifted_data[ifo.name]).T) def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idxs): """ @@ -943,14 +966,15 @@ def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idx """ for ifo in self.interferometers: self.weights[ifo.name + '_quadratic'] = [] + xp = aac.array_namespace(self.interferometers.frequency_array) for basis_idx in basis_idxs: logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.") - quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real + quadratic_matrix_single = xp.asarray(quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real) for ifo in self.interferometers: + inv_psd = xp.asarray(1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]]) self.weights[ifo.name + '_quadratic'].append( - 4. / ifo.strain_data.duration * np.dot( - quadratic_matrix_single[:, roq_idxs[ifo.name]], - 1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]])) + 4. 
/ ifo.strain_data.duration * quadratic_matrix_single[:, roq_idxs[ifo.name]] @ inv_psd + ) del quadratic_matrix_single def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): @@ -967,6 +991,7 @@ def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): """ for ifo in self.interferometers: self.weights[ifo.name + '_quadratic'] = [] + xp = aac.array_namespace(self.interferometers.frequency_array) Tbs = quadratic_matrix['durations_s_quadratic'][()] / self.roq_scale_factor start_end_frequency_bins = quadratic_matrix['start_end_frequency_bins_quadratic'][()] basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1) @@ -974,27 +999,31 @@ def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): # prepare coefficients multiplied by basis multibanded_inverse_psd = dict() for ifo in self.interferometers: - inverse_psd_frequency = np.zeros(int(fhigh_basis * ifo.duration) + 1) - inverse_psd_frequency[np.arange(len(ifo.power_spectral_density_array))[ifo.frequency_mask]] = \ - 1. / ifo.power_spectral_density_array[ifo.frequency_mask] - inverse_psd_time = np.fft.irfft(inverse_psd_frequency) - multibanded_inverse_psd[ifo.name] = np.zeros(basis_dimension) + inverse_psd_frequency = xp.zeros(int(fhigh_basis * ifo.duration) + 1) + sl = np.arange(len(ifo.power_spectral_density_array))[ifo.frequency_mask] + inverse_psd_frequency = xpx.at(inverse_psd_frequency, sl).set( + 1. / xp.asarray(ifo.power_spectral_density_array[ifo.frequency_mask]) + ) + inverse_psd_time = xp.fft.irfft(inverse_psd_frequency) + multibanded_inverse_psd[ifo.name] = xp.zeros(basis_dimension) start_idx_of_band = 0 for b, Tb in enumerate(Tbs): start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b] number_of_samples_half = int(fhigh_basis * Tb) start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1 - multibanded_inverse_psd[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. 
/ Tb * np.fft.rfft( - np.append(inverse_psd_time[:number_of_samples_half], inverse_psd_time[-number_of_samples_half:]) + sl = slice(start_idx_of_band, start_idx_of_next_band) + this_data = 4. / Tb * xp.fft.rfft( + xp.concat([inverse_psd_time[:number_of_samples_half], inverse_psd_time[-number_of_samples_half:]]) )[start_frequency_bin:end_frequency_bin + 1].real + multibanded_inverse_psd[ifo.name] = xpx.at(multibanded_inverse_psd[ifo.name], sl).set(this_data) start_idx_of_band = start_idx_of_next_band # compute inner products for basis_idx in basis_idxs: logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.") - quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real + quadratic_matrix_single = xp.asarray(quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real) for ifo in self.interferometers: self.weights[ifo.name + '_quadratic'].append( - np.dot(quadratic_matrix_single, multibanded_inverse_psd[ifo.name])) + quadratic_matrix_single @ multibanded_inverse_psd[ifo.name]) def save_weights(self, filename, format='hdf5'): """ @@ -1209,8 +1238,8 @@ def generate_time_sample_from_marginalized_likelihood(self, signal_polarizations times = self._times if self.jitter_time: times = times + parameters["time_jitter"] - time_prior_array = self.priors['geocent_time'].prob(times) - time_post = np.exp(time_log_like - max(time_log_like)) * time_prior_array + time_prior_array = np.asarray(self.priors['geocent_time'].prob(times)) + time_post = np.exp(np.asarray(time_log_like - max(time_log_like))) * time_prior_array time_post /= np.sum(time_post) return random.rng.choice(times, p=time_post) diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index e912bc354..1109b6692 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -16,7 +16,13 @@ class BackendWaveformGenerator(bilby.gw.waveform_generator.WaveformGenerator): - """A thin wrapper to emulate different backends in the 
waveform generator.""" + """ + A thin wrapper to emulate different backends in the waveform generator. + + This ensures that all frequency arrays that might be used inside the + source are cast to numpy for compatibility. The outputs are converted + to the appropriate array type. + """ def __init__(self, wfg, xp): self.wfg = wfg self.xp = xp @@ -35,12 +41,15 @@ def convert_nested_dict(self, data): raise ValueError("Input must be an array API object or a dict of such objects.") def _strain_from_model(self, model_data_points, model, parameters): - # we can't pass a frequency array through as a torch array model_data_points = np.asarray(model_data_points) return super()._strain_from_model(model_data_points, model, parameters) def frequency_domain_strain(self, parameters): self.wfg.frequency_array = np.asarray(self.wfg.frequency_array) + if "frequency_nodes" in self.wfg.waveform_arguments: + self.wfg.waveform_arguments["frequency_nodes"] = np.asarray( + self.wfg.waveform_arguments["frequency_nodes"] + ) wf = self.wfg.__class__.frequency_domain_strain(self, parameters) return self.convert_nested_dict(wf) @@ -335,13 +344,31 @@ def test_time_reference_agrees_with_default(self): ) +class ROQBasisMixin: + + @property + def roq_dir(self): + trial_roq_paths = [ + "/roq_basis", + os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), + "/home/cbc/ROQ_data/IMRPhenomPv2/4s", + ] + if "BILBY_TESTING_ROQ_DIR" in os.environ: + trial_roq_paths.insert(0, os.environ["BILBY_TESTING_ROQ_DIR"]) + for path in trial_roq_paths: + if os.path.isdir(path): + return path + raise Exception("Unable to load ROQ basis: cannot proceed with tests") + @pytest.mark.requires_roqs @pytest.mark.array_backend @pytest.mark.usefixtures("xp_class") -class TestROQLikelihood(unittest.TestCase): +@pytest.mark.flaky(reruns=3) # pyfftw is flake on some machines +class TestROQLikelihood(ROQBasisMixin, unittest.TestCase): def setUp(self): self.duration = self.xp.asarray(4.0) self.sampling_frequency = 
self.xp.asarray(2048.0) + bilby.core.utils.random.seed(500) self.test_parameters = dict( mass_1=36.0, @@ -412,22 +439,6 @@ def tearDown(self): self.priors, ) - @property - def roq_dir(self): - trial_roq_paths = [ - "/roq_basis", - os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), - "/home/cbc/ROQ_data/IMRPhenomPv2/4s", - ] - if "BILBY_TESTING_ROQ_DIR" in os.environ: - trial_roq_paths.insert(0, os.environ["BILBY_TESTING_ROQ_DIR"]) - print(trial_roq_paths) - for path in trial_roq_paths: - print(path, os.path.isdir(path)) - if os.path.isdir(path): - return path - raise Exception("Unable to load ROQ basis: cannot proceed with tests") - @property def linear_matrix_file(self): return f"{self.roq_dir}/B_linear.npy" @@ -631,33 +642,18 @@ def test_create_roq_weights_fails_due_to_duration(self): @pytest.mark.requires_roqs -class TestRescaledROQLikelihood(unittest.TestCase): +class TestRescaledROQLikelihood(unittest.TestCase, ROQBasisMixin): def test_rescaling(self): + linear_matrix_file = f"{self.roq_dir}/B_linear.npy" + quadratic_matrix_file = f"{self.roq_dir}/B_quadratic.npy" - # Possible locations for the ROQ: in the docker image, local, or on CIT - trial_roq_paths = [ - "/roq_basis", - os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"), - "/home/cbc/ROQ_data/IMRPhenomPv2/4s", - ] - roq_dir = None - for path in trial_roq_paths: - if os.path.isdir(path): - roq_dir = path - break - if roq_dir is None: - raise Exception("Unable to load ROQ basis: cannot proceed with tests") - - linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - - fnodes_linear_file = "{}/fnodes_linear.npy".format(roq_dir) + fnodes_linear_file = f"{self.roq_dir}/fnodes_linear.npy" fnodes_linear = np.load(fnodes_linear_file).T - fnodes_quadratic_file = "{}/fnodes_quadratic.npy".format(roq_dir) + fnodes_quadratic_file = f"{self.roq_dir}/fnodes_quadratic.npy" fnodes_quadratic = np.load(fnodes_quadratic_file).T - 
self.linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - self.quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - self.params_file = "{}/params.dat".format(roq_dir) + self.linear_matrix_file = f"{self.roq_dir}/B_linear.npy" + self.quadratic_matrix_file = f"{self.roq_dir}/B_quadratic.npy" + self.params_file = f"{self.roq_dir}/params.dat" scale_factor = 0.5 params = np.genfromtxt(self.params_file, names=True) @@ -707,7 +703,7 @@ def test_rescaling(self): @pytest.mark.requires_roqs @pytest.mark.array_backend @pytest.mark.usefixtures("xp_class") -class TestROQLikelihoodHDF5(unittest.TestCase): +class TestROQLikelihoodHDF5(unittest.TestCase, ROQBasisMixin): """ Test ROQ likelihood constructed from .hdf5 basis @@ -715,9 +711,8 @@ class TestROQLikelihoodHDF5(unittest.TestCase): respectively, and 2 quadratic bases constructed over 8Msun= self.priors["chirp_mass"].minimum) * @@ -1006,8 +1001,8 @@ def assertLess_likelihood_errors( interferometers=interferometers, priors=self.priors, waveform_generator=search_waveform_generator, - linear_matrix=basis_linear, - quadratic_matrix=basis_quadratic, + linear_matrix=f"{self.roq_dir}/{basis_linear}", + quadratic_matrix=f"{self.roq_dir}/{basis_quadratic}", roq_scale_factor=roq_scale_factor ) for mc in np.linspace(self.priors["chirp_mass"].minimum, self.priors["chirp_mass"].maximum, 11): @@ -1025,7 +1020,7 @@ def assertLess_likelihood_errors( @pytest.mark.requires_roqs -class TestCreateROQLikelihood(unittest.TestCase): +class TestCreateROQLikelihood(unittest.TestCase, ROQBasisMixin): """ Test if ROQ likelihood is constructed without any errors from .hdf5 or .npy basis @@ -1033,9 +1028,8 @@ class TestCreateROQLikelihood(unittest.TestCase): respectively, and 2 quadratic bases constructed over 8Msun Date: Tue, 3 Feb 2026 15:22:43 -0500 Subject: [PATCH 103/110] BUG: fix a missing array case --- bilby/core/prior/analytical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/analytical.py 
b/bilby/core/prior/analytical.py index 02e65bf3a..69afd5926 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -1011,7 +1011,7 @@ def ln_prob(self, val, *, xp=None): ======= Union[float, array_like]: Prior probability of val """ - ln_prob = xlog1py(self.beta - 1.0, -val) + xlogy(self.alpha - 1.0, val) + ln_prob = xlog1py(self.beta - 1.0, -val) + xlogy(xp.asarray(self.alpha - 1.0), val) ln_prob -= betaln(self.alpha, self.beta) return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) From 8159bb8f7eefb7b59e3651db30efb32d0e770b32 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 15:34:45 -0500 Subject: [PATCH 104/110] FMT: pre-commit fixes --- bilby/gw/likelihood/roq.py | 9 ++++++--- test/gw/likelihood_test.py | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index c52a32a59..d2564ce94 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -918,7 +918,9 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): over_whitened_frequency_data = xp.zeros(int(fhigh_basis * ifo.duration) + 1, dtype=complex) over_whitened_frequency_data = xpx.at( over_whitened_frequency_data, xp.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask] - ).set(ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask]) + ).set( + ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] + ) over_whitened_time_data = xp.fft.irfft(over_whitened_frequency_data) tc_shifted_data[ifo.name] = xp.zeros((basis_dimension, len(self.weights['time_samples'])), dtype=complex) start_idx_of_band = 0 @@ -933,7 +935,8 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): sl = slice(start_idx_of_band, start_idx_of_next_band) this_data = ( 4. / Tb * Db[:, None] * xp.exp( - 2. 
* np.pi * 1j * fs[:, None] * (xp.asarray(self.weights['time_samples'][None, :]) - ifo.duration + Tb) + 2. * np.pi * 1j * fs[:, None] + * (xp.asarray(self.weights['time_samples'][None, :]) - ifo.duration + Tb) ) ) tc_shifted_data[ifo.name] = xpx.at(tc_shifted_data[ifo.name], sl).set(this_data) @@ -973,7 +976,7 @@ def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idx for ifo in self.interferometers: inv_psd = xp.asarray(1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]]) self.weights[ifo.name + '_quadratic'].append( - 4. / ifo.strain_data.duration * quadratic_matrix_single[:, roq_idxs[ifo.name]] @ inv_psd + 4. / ifo.strain_data.duration * quadratic_matrix_single[:, roq_idxs[ifo.name]] @ inv_psd ) del quadratic_matrix_single diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 1109b6692..40be9aaa7 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -360,6 +360,7 @@ def roq_dir(self): return path raise Exception("Unable to load ROQ basis: cannot proceed with tests") + @pytest.mark.requires_roqs @pytest.mark.array_backend @pytest.mark.usefixtures("xp_class") From 438ed6438dae300c399245595152ab3e2bf6d8bc Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 15:54:29 -0500 Subject: [PATCH 105/110] CI: drop torch tests for python 3.10 --- .github/workflows/unit-tests.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index a0b81d424..f7265fe5f 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -57,9 +57,6 @@ jobs: # - name: Run precommits # run: | # pre-commit run --all-files --verbose --show-diff-on-failure - - name: Run torch-backend unit tests - run: | - SCIPY_ARRAY_API=1 pytest --array-backend torch --durations 10 - name: Run unit tests run: | python -m pytest --cov=bilby --cov-branch --durations 10 -ra --color yes --cov-report=xml 
--junitxml=pytest.xml @@ -67,6 +64,11 @@ jobs: run: | python -m pip install .[jax] SCIPY_ARRAY_API=1 pytest --array-backend jax --durations 10 + - name: Run torch-backend unit tests + # there are scipy version issues with python 3.10 and torch + if: matrix.python.version > 3.10 + run: | + SCIPY_ARRAY_API=1 pytest --array-backend torch --durations 10 - name: Run sampler tests run: | pytest test/integration/sampler_run_test.py --durations 10 -v From 00bee6da4e7e5f6d205851d81f2e1a322bc05e90 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 16:26:40 -0500 Subject: [PATCH 106/110] FMT: precommit fix --- bilby/gw/likelihood/roq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index d2564ce94..a73d99b5a 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -940,7 +940,7 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): ) ) tc_shifted_data[ifo.name] = xpx.at(tc_shifted_data[ifo.name], sl).set(this_data) - + start_idx_of_band = start_idx_of_next_band # compute inner products for basis_idx in basis_idxs: From f152742350d4335288ce4731c6b2d54b571ab409 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 16:40:28 -0500 Subject: [PATCH 107/110] TEST: exclude studentt tests for jax --- test/core/prior/prior_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index a3165adce..67fbd2422 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -259,6 +259,11 @@ def condition_func(reference_params, test_param): p for p in self.priors if not isinstance(p, bilby.core.prior.Interped) ] + elif aac.is_jax_namespace(self.xp): + self.priors = [ + p for p in self.priors + if not isinstance(p, bilby.core.prior.StudentT) + ] def tearDown(self): del self.priors @@ -798,7 +803,6 @@ def test_accuracy(self): 
bilby.core.prior.WeightedDiscreteValues, ) if isinstance(prior, (testTuple)): - print(prior) np.testing.assert_almost_equal(prior.prob(self.xp.asarray(domain)), scipy_prob) np.testing.assert_almost_equal(prior.ln_prob(self.xp.asarray(domain)), scipy_lnprob) np.testing.assert_almost_equal(prior.cdf(self.xp.asarray(domain)), scipy_cdf) From 4cfa4fb9405df768fe35312850d1a45a239b9728 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 3 Feb 2026 17:58:10 -0500 Subject: [PATCH 108/110] Add some more explicit array casts --- bilby/core/prior/analytical.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 69afd5926..bec049d90 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -981,7 +981,8 @@ def rescale(self, val, *, xp=None): This explicitly casts to the requested backend, but the computation will be done by scipy. """ return ( - xp.asarray(betaincinv(self.alpha, self.beta, val)) * (self.maximum - self.minimum) + xp.asarray(betaincinv(xp.asarray(self.alpha), xp.asarray(self.beta), val)) + * (self.maximum - self.minimum) + self.minimum ) @@ -1011,14 +1012,18 @@ def ln_prob(self, val, *, xp=None): ======= Union[float, array_like]: Prior probability of val """ - ln_prob = xlog1py(self.beta - 1.0, -val) + xlogy(xp.asarray(self.alpha - 1.0), val) - ln_prob -= betaln(self.alpha, self.beta) + ln_prob = xlog1py(xp.asarray(self.beta - 1.0), -val) + xlogy(xp.asarray(self.alpha - 1.0), val) + ln_prob -= betaln(xp.asarray(self.alpha), xp.asarray(self.beta)) return xp.nan_to_num(ln_prob, nan=-xp.inf, neginf=-xp.inf, posinf=-xp.inf) @xp_wrap def cdf(self, val, *, xp=None): return xp.nan_to_num( - betainc(self.alpha, self.beta, (val - self.minimum) / (self.maximum - self.minimum)) + betainc( + xp.asarray(self.alpha), + xp.asarray(self.beta), + (val - self.minimum) / (self.maximum - self.minimum) + ) ) + (val > self.maximum) From 
b638d8e064d8d9dd5447adebd27146729b103113 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Tue, 17 Feb 2026 09:44:41 -0500 Subject: [PATCH 109/110] BUG: bug fixes for prior and gw likelihoods --- bilby/core/prior/dict.py | 27 +++++++++----------------- bilby/core/utils/calculus.py | 4 ++-- bilby/gw/likelihood/base.py | 37 +++++++++++++++++++----------------- 3 files changed, 31 insertions(+), 37 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index dd8e586e6..c3e61d569 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -5,7 +5,7 @@ from io import open as ioopen from warnings import warn -import array_api_extra as xpx +import array_api_compat as aac import numpy as np from .analytical import DeltaFunction @@ -467,7 +467,7 @@ def check_efficiency(n_tested, n_valid): sample = self.sample_subset(keys=keys, size=size, xp=xp) is_valid = self.evaluate_constraints(sample) n_tested_samples += 1 - n_valid_samples += int(is_valid) + n_valid_samples += int(is_valid.item()) check_efficiency(n_tested_samples, n_valid_samples) if is_valid: return sample @@ -554,7 +554,7 @@ def prob(self, sample, *, xp=None, **kwargs): @xp_wrap def check_prob(self, sample, prob, *, xp=None): ratio = self.normalize_constraint_factor(tuple(sample.keys())) - if xp.all(prob == 0.0): + if not aac.is_jax_namespace(xp) and xp.all(prob == 0.0): return prob * ratio else: if isinstance(prob, float): @@ -563,11 +563,8 @@ def check_prob(self, sample, prob, *, xp=None): else: return 0.0 else: - constrained_prob = xp.zeros_like(prob) - in_bounds = xp.isfinite(prob) - subsample = {key: sample[key][in_bounds] for key in sample} - keep = self.evaluate_constraints(subsample, xp=xp) - constrained_prob = xpx.at(constrained_prob, in_bounds).set(prob[in_bounds] * keep * ratio) + keep = self.evaluate_constraints(sample, xp=xp) + constrained_prob = xp.where(keep, prob * ratio, 0.0) return constrained_prob @xp_wrap @@ -591,8 +588,7 @@ def ln_prob(self, sample, axis=None, 
normalized=True, *, xp=None): """ ln_prob = xp.sum(xp.stack([self[key].ln_prob(sample[key], xp=xp) for key in sample]), axis=axis) - return self.check_ln_prob(sample, ln_prob, - normalized=normalized, xp=xp) + return self.check_ln_prob(sample, ln_prob, normalized=normalized, xp=xp) @xp_wrap def check_ln_prob(self, sample, ln_prob, normalized=True, *, xp=None): @@ -600,7 +596,7 @@ def check_ln_prob(self, sample, ln_prob, normalized=True, *, xp=None): ratio = self.normalize_constraint_factor(tuple(sample.keys())) else: ratio = 1 - if xp.all(xp.isfinite(ln_prob)): + if not aac.is_jax_namespace(xp) and xp.all(xp.isfinite(ln_prob)): return ln_prob else: if isinstance(ln_prob, float): @@ -609,13 +605,8 @@ def check_ln_prob(self, sample, ln_prob, normalized=True, *, xp=None): else: return -np.inf else: - constrained_ln_prob = -np.inf * xp.ones_like(ln_prob) - in_bounds = xp.isfinite(ln_prob) - subsample = {key: sample[key][in_bounds] for key in sample} - keep = xp.log(self.evaluate_constraints(subsample, xp=xp)) - constrained_ln_prob = xpx.at(constrained_ln_prob, in_bounds).set( - ln_prob[in_bounds] + keep + xp.log(ratio) - ) + keep = self.evaluate_constraints(sample, xp=xp) + constrained_ln_prob = xp.where(keep, ln_prob + xp.log(ratio), -xp.inf) return constrained_ln_prob @xp_wrap diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index 496946c2b..f97ee1a9b 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -263,8 +263,8 @@ def _call_jax(self, x, y): from interpax import interp2d return interp2d( - x, - y, + jnp.asarray(x), + jnp.asarray(y), jnp.asarray(self.x), jnp.asarray(self.y), jnp.asarray(self.z), diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index bc8296915..5a7e957a7 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -2,6 +2,7 @@ import os import copy +import array_api_compat as aac import attr import numpy as np from scipy.special import logsumexp @@ -297,32 
+298,33 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr optimal_snr_squared_array = None normalization = 4 / self.waveform_generator.duration + xp = signal.__array_namespace__() if return_array is False: d_inner_h_array = None optimal_snr_squared_array = None elif self.time_marginalization and self.calibration_marginalization: - d_inner_h_integrand = np.tile( + d_inner_h_integrand = xp.tile( interferometer.frequency_domain_strain.conj() * signal / interferometer.power_spectral_density_array, (self.number_of_response_curves, 1)).T d_inner_h_integrand[_mask] *= self.calibration_draws[interferometer.name].T - d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft( + d_inner_h_array = 4 / self.waveform_generator.duration * xp.fft.fft( d_inner_h_integrand[0:-1], axis=0 ).T optimal_snr_squared_integrand = ( - normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array + normalization * xp.abs(signal)**2 / interferometer.power_spectral_density_array ) - optimal_snr_squared_array = np.dot( + optimal_snr_squared_array = xp.dot( optimal_snr_squared_integrand[_mask], self.calibration_abs_draws[interferometer.name].T ) elif self.time_marginalization and not self.calibration_marginalization: - d_inner_h_array = normalization * np.fft.fft( + d_inner_h_array = normalization * xp.fft.fft( signal[0:-1] * interferometer.frequency_domain_strain.conj()[0:-1] / interferometer.power_spectral_density_array[0:-1] @@ -334,12 +336,12 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr interferometer.frequency_domain_strain.conj() * signal / interferometer.power_spectral_density_array ) - d_inner_h_array = np.dot(d_inner_h_integrand[_mask], self.calibration_draws[interferometer.name].T) + d_inner_h_array = xp.dot(d_inner_h_integrand[_mask], self.calibration_draws[interferometer.name].T) optimal_snr_squared_integrand = ( - normalization * np.abs(signal)**2 / 
interferometer.power_spectral_density_array + normalization * xp.abs(signal)**2 / interferometer.power_spectral_density_array ) - optimal_snr_squared_array = np.dot( + optimal_snr_squared_array = xp.dot( optimal_snr_squared_integrand[_mask], self.calibration_abs_draws[interferometer.name].T ) @@ -802,14 +804,15 @@ def time_marginalized_likelihood(self, d_inner_h_tc_array, h_inner_h, parameters if self.jitter_time: times = self._times + parameters['time_jitter'] - _time_prior = self.priors['geocent_time'] - time_mask = (times >= _time_prior.minimum) & (times <= _time_prior.maximum) - times = times[time_mask] + if not aac.is_jax_array(d_inner_h_tc_array): + _time_prior = self.priors['geocent_time'] + time_mask = (times >= _time_prior.minimum) & (times <= _time_prior.maximum) + times = times[time_mask] + if self.calibration_marginalization: + d_inner_h_tc_array = d_inner_h_tc_array[:, time_mask] + else: + d_inner_h_tc_array = d_inner_h_tc_array[time_mask] time_prior_array = self.priors['geocent_time'].prob(times) * self._delta_tc - if self.calibration_marginalization: - d_inner_h_tc_array = d_inner_h_tc_array[:, time_mask] - else: - d_inner_h_tc_array = d_inner_h_tc_array[time_mask] if self.distance_marginalization: log_l_tc_array = self.distance_marginalized_likelihood( @@ -819,9 +822,9 @@ def time_marginalized_likelihood(self, d_inner_h_tc_array, h_inner_h, parameters d_inner_h=d_inner_h_tc_array, h_inner_h=h_inner_h) elif self.calibration_marginalization: - log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h[:, np.newaxis] / 2 + log_l_tc_array = d_inner_h_tc_array.real - h_inner_h[:, np.newaxis] / 2 else: - log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h / 2 + log_l_tc_array = d_inner_h_tc_array.real - h_inner_h / 2 return logsumexp(log_l_tc_array, b=time_prior_array, axis=-1) def get_calibration_log_likelihoods(self, signal_polarizations=None, parameters=None): From 922367778f1a9fca72b44f42906b9a34f5f098aa Mon Sep 17 00:00:00 2001 From: Colm Talbot 
Date: Tue, 17 Feb 2026 10:18:16 -0500 Subject: [PATCH 110/110] FEAT: add precessing spin transformation directly into bilby --- bilby/core/utils/constants.py | 1 + bilby/gw/conversion.py | 28 +-- bilby/gw/geometry.py | 174 +++++++++++++++++- bilby/gw/source.py | 5 +- .../injection_examples/jax_fast_tutorial.py | 28 +-- test/gw/geometry_test.py | 107 +++++++++++ 6 files changed, 291 insertions(+), 52 deletions(-) create mode 100644 test/gw/geometry_test.py diff --git a/bilby/core/utils/constants.py b/bilby/core/utils/constants.py index 8dbec27da..e734923a5 100644 --- a/bilby/core/utils/constants.py +++ b/bilby/core/utils/constants.py @@ -5,3 +5,4 @@ solar_mass = 1.988409870698050731911960804878414216e30 # Kg radius_of_earth = 6378136.6 # m gravitational_constant = 6.6743e-11 # m^3 kg^-1 s^-2 +msun_time_si = gravitational_constant * solar_mass / speed_of_light**3 # s diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index cc1b3e493..bc2d35930 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -33,6 +33,7 @@ from .utils import lalsim_SimInspiralTransformPrecessingNewInitialConditions from .eos.eos import IntegrateTOV from .cosmology import get_cosmology, z_at_value +from .geometry import transform_precessing_spins def redshift_to_luminosity_distance(redshift, cosmology=None): @@ -152,33 +153,14 @@ def bilby_to_lalsimulation_spins( spin_2z = a_2 * np.cos(tilt_2) iota = theta_jn else: - from numbers import Number - args = ( - theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, - mass_2, reference_frequency, phase + func = transform_precessing_spins + iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = func( + theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, + mass_1, mass_2, reference_frequency, phase ) - float_inputs = all([isinstance(arg, Number) for arg in args]) - if float_inputs: - func = lalsim_SimInspiralTransformPrecessingNewInitialConditions - else: - func = transform_precessing_spins - iota, spin_1x, spin_1y, 
spin_1z, spin_2x, spin_2y, spin_2z = func(*args) return iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z -@np.vectorize -def transform_precessing_spins(*args): - """ - Vectorized wrapper for - lalsimulation.SimInspiralTransformPrecessingNewInitialConditions - - For detailed documentation see - :code:`bilby.gw.conversion.bilby_to_lalsimulation_spins`. - This will be removed from the public API in a future release. - """ - return lalsim_SimInspiralTransformPrecessingNewInitialConditions(*args) - - def convert_to_lal_binary_black_hole_parameters(parameters): """ Convert parameters we have into parameters we need. diff --git a/bilby/gw/geometry.py b/bilby/gw/geometry.py index 68321d4b4..fabed5315 100644 --- a/bilby/gw/geometry.py +++ b/bilby/gw/geometry.py @@ -2,7 +2,7 @@ from .time import greenwich_mean_sidereal_time from ..compat.utils import array_module, promote_to_array - +from ..core.utils.constants import msun_time_si __all__ = [ "antenna_response", @@ -193,3 +193,175 @@ def zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x): theta = xp.arccos(omega[2]) phi = xp.arctan2(omega[1], omega[0]) % (2 * xp.pi) return theta, phi + + +def transform_precessing_spins( + theta_jn, + phi_jl, + tilt_1, + tilt_2, + phi_12, + chi_1, + chi_2, + mass_1, + mass_2, + f_ref, + phase, +): + """ + A direct reimplementation of + :code:`lalsimulation.SimInspiralTransformPrecessingNewInitialConditions`. + + Parameters + ---------- + theta_jn: float | xp.ndarray + Zenith angle between J and N (rad). + phi_jl: float | xp.ndarray + Azimuthal angle of L_N on its cone about J (rad). + tilt_1: float | xp.ndarray + Zenith angle between S1 and LNhat (rad). + tilt_2: float | xp.ndarray + Zenith angle between S2 and LNhat (rad). + phi_12: float | xp.ndarray + Difference in azimuthal angle between S1, S2 (rad). + chi_1: float | xp.ndarray + Dimensionless spin of body 1. + chi_2: float | xp.ndarray + Dimensionless spin of body 2. 
+ mass_1: float | xp.ndarray + Mass of body 1 (solar masses). + mass_2: float | xp.ndarray + Mass of body 2 (solar masses). + f_ref: float | xp.ndarray + Reference GW frequency (Hz). + phase: float | xp.ndarray + Reference orbital phase. + + Returns + ------- + tuple + (iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z) + - iota: Inclination angle of L_N + - spin_1x, spin_1y, spin_1z: Components of spin 1 + - spin_2x, spin_2y, spin_2z: Components of spin 2 + """ + + xp = array_module(theta_jn) + pi = xp.pi + + # Helper rotation functions + def rotate_z(angle, vec): + """Rotate vector about z-axis""" + cos_a = xp.cos(angle) + sin_a = xp.sin(angle) + x_new = cos_a * vec[0] - sin_a * vec[1] + y_new = sin_a * vec[0] + cos_a * vec[1] + return xp.stack([x_new, y_new, vec[2]], axis=0) + + def rotate_y(angle, vec): + """Rotate vector about y-axis""" + cos_a = xp.cos(angle) + sin_a = xp.sin(angle) + x_new = cos_a * vec[0] + sin_a * vec[2] + z_new = -sin_a * vec[0] + cos_a * vec[2] + return xp.stack([x_new, vec[1], z_new], axis=0) + + # Starting frame: LNhat is along the z-axis + ln_hat = xp.stack([ + xp.zeros_like(theta_jn), + xp.zeros_like(theta_jn), + xp.ones_like(theta_jn) + ], axis=0) + + # Initial spin unit vectors + s1_hat = xp.stack([ + xp.sin(tilt_1) * xp.cos(phase), + xp.sin(tilt_1) * xp.sin(phase), + xp.cos(tilt_1) + ], axis=0) + + s2_hat = xp.stack([ + xp.sin(tilt_2) * xp.cos(phi_12 + phase), + xp.sin(tilt_2) * xp.sin(phi_12 + phase), + xp.cos(tilt_2) + ], axis=0) + + # Compute physical parameters + m_total = mass_1 + mass_2 + eta = mass_1 * mass_2 / (m_total * m_total) + + # v parameter at reference point (c=G=1 units) + v0 = (m_total * msun_time_si * pi * f_ref) ** (1/3) + + # Compute angular momentum magnitude using PN expressions + # L/M = eta * v^(-1) * (1 + v^2 * L_2PN) + # L_2PN = 3/2 + 1/6 * eta + l_2pn = 1.5 + eta / 6.0 + l_mag = eta * m_total * m_total / v0 * (1.0 + v0 * v0 * l_2pn) + + # Spin vectors with proper magnitudes + s1 = mass_1 * 
mass_1 * chi_1 * s1_hat + s2 = mass_2 * mass_2 * chi_2 * s2_hat + + # Total angular momentum J = L + S1 + S2 + l_vec = xp.stack([xp.zeros_like(theta_jn), xp.zeros_like(theta_jn), l_mag], axis=0) + j = l_vec + s1 + s2 + + # Normalize J to get Jhat and find its angles + j_norm = xp.sqrt(xp.sum(j * j, axis=0)) + j_hat = j / j_norm + + theta_0 = xp.arccos(j_hat[2]) + phi_0 = xp.arctan2(j_hat[1], j_hat[0]) + + # Rotation 1: Rotate about z-axis by -phi_0 to put Jhat in x-z plane + angle = -phi_0 + s1_hat = rotate_z(angle, s1_hat) + s2_hat = rotate_z(angle, s2_hat) + + # Rotation 2: Rotate about y-axis by -theta_0 to put Jhat along z-axis + angle = -theta_0 + ln_hat = rotate_y(angle, ln_hat) + s1_hat = rotate_y(angle, s1_hat) + s2_hat = rotate_y(angle, s2_hat) + + # Rotation 3: Rotate about z-axis by (phi_jl - pi) to put L at desired azimuth + angle = phi_jl - pi + ln_hat = rotate_z(angle, ln_hat) + s1_hat = rotate_z(angle, s1_hat) + s2_hat = rotate_z(angle, s2_hat) + + # Compute inclination: angle between L and N + n = xp.stack([ + xp.zeros_like(theta_jn), + xp.sin(theta_jn), + xp.cos(theta_jn) + ], axis=0) + iota = xp.arccos(xp.sum(n * ln_hat, axis=0)) + + # Rotation 4-5: Bring L into the z-axis + theta_lj = xp.arccos(ln_hat[2]) + phi_l = xp.arctan2(ln_hat[1], ln_hat[0]) + + angle = -phi_l + s1_hat = rotate_z(angle, s1_hat) + s2_hat = rotate_z(angle, s2_hat) + n = rotate_z(angle, n) + + angle = -theta_lj + s1_hat = rotate_y(angle, s1_hat) + s2_hat = rotate_y(angle, s2_hat) + n = rotate_y(angle, n) + + # Rotation 6: Bring N into y-z plane with positive y component + phi_n = xp.arctan2(n[1], n[0]) + + angle = pi / 2.0 - phi_n - phase + s1_hat = rotate_z(angle, s1_hat) + s2_hat = rotate_z(angle, s2_hat) + + # Return final spin components + spin_1 = s1_hat * chi_1 + spin_2 = s2_hat * chi_2 + + return iota, *spin_1, *spin_2 diff --git a/bilby/gw/source.py b/bilby/gw/source.py index cc08d2d65..f8e732046 100644 --- a/bilby/gw/source.py +++ b/bilby/gw/source.py @@ -618,14 
+618,15 @@ def _base_lal_cbc_fd_waveform( (frequency_array <= maximum_frequency)) luminosity_distance = luminosity_distance * 1e6 * utils.parsec - mass_1 = mass_1 * utils.solar_mass - mass_2 = mass_2 * utils.solar_mass iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_lalsimulation_spins( theta_jn=theta_jn, phi_jl=phi_jl, tilt_1=tilt_1, tilt_2=tilt_2, phi_12=phi_12, a_1=a_1, a_2=a_2, mass_1=mass_1, mass_2=mass_2, reference_frequency=reference_frequency, phase=phase) + mass_1 = mass_1 * utils.solar_mass + mass_2 = mass_2 * utils.solar_mass + longitude_ascending_nodes = 0.0 mean_per_ano = 0.0 diff --git a/examples/gw_examples/injection_examples/jax_fast_tutorial.py b/examples/gw_examples/injection_examples/jax_fast_tutorial.py index 56b1b4d3a..8fe3f6eb1 100644 --- a/examples/gw_examples/injection_examples/jax_fast_tutorial.py +++ b/examples/gw_examples/injection_examples/jax_fast_tutorial.py @@ -24,30 +24,6 @@ jax.config.update("jax_enable_x64", True) -def bilby_to_ripple_spins( - theta_jn, - phi_jl, - tilt_1, - tilt_2, - phi_12, - a_1, - a_2, -): - """ - A simplified spherical to cartesian spin conversion function. - This is not equivalent to the method used in `bilby.gw.conversion` - which comes from `lalsimulation` and is not `JAX` compatible. - """ - iota = theta_jn - spin_1x = a_1 * jnp.sin(tilt_1) * jnp.cos(phi_jl) - spin_1y = a_1 * jnp.sin(tilt_1) * jnp.sin(phi_jl) - spin_1z = a_1 * jnp.cos(tilt_1) - spin_2x = a_2 * jnp.sin(tilt_2) * jnp.cos(phi_jl + phi_12) - spin_2y = a_2 * jnp.sin(tilt_2) * jnp.sin(phi_jl + phi_12) - spin_2z = a_2 * jnp.cos(tilt_2) - return iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z - - def ripple_bbh( frequency, mass_1, @@ -102,8 +78,8 @@ def ripple_bbh( dict Dictionary containing the plus and cross polarizations of the waveform. 
""" - iota, *cartesian_spins = bilby_to_ripple_spins( - theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2 + iota, *cartesian_spins = bilby.gw.geometry.transform_precessing_spins( + theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, f_ref, phase ) frequencies = jnp.maximum(frequency, kwargs["minimum_frequency"]) theta = jnp.array( diff --git a/test/gw/geometry_test.py b/test/gw/geometry_test.py new file mode 100644 index 000000000..3d87c628e --- /dev/null +++ b/test/gw/geometry_test.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest + + +@pytest.mark.array_backend +def test_transform_precessing_spins(xp): + """ + Verify that our port of this function matches the lalsimulation version. + """ + import lal + from bilby.core.prior import Uniform + from bilby.gw.prior import BBHPriorDict + from bilby.gw.geometry import transform_precessing_spins + from lalsimulation import SimInspiralTransformPrecessingNewInitialConditions + + priors = BBHPriorDict() + priors["mass_1"] = Uniform(1, 1000) + priors["mass_2"] = Uniform(1, 1000) + priors["reference_frequency"] = Uniform(10, 100) + + # some default priors are problematic for some array backends + for key in ["luminosity_distance", "chirp_mass", "mass_ratio"]: + del priors[key] + + for _ in range(100): + point = priors.sample(xp=xp) + bilby_transformed = np.asarray(transform_precessing_spins( + point["theta_jn"], + point["phi_jl"], + point["tilt_1"], + point["tilt_2"], + point["phi_12"], + point["a_1"], + point["a_2"], + point["mass_1"], + point["mass_2"], + point["reference_frequency"], + point["phase"], + )) + lalsim_transformed = np.asarray(SimInspiralTransformPrecessingNewInitialConditions( + float(point["theta_jn"]), + float(point["phi_jl"]), + float(point["tilt_1"]), + float(point["tilt_2"]), + float(point["phi_12"]), + float(point["a_1"]), + float(point["a_2"]), + float(point["mass_1"] * lal.MSUN_SI), + float(point["mass_2"] * lal.MSUN_SI), + float(point["reference_frequency"]), + 
float(point["phase"]), + )) + np.testing.assert_allclose(bilby_transformed, lalsim_transformed, rtol=1e-10) + + +@pytest.mark.array_backend +def test_transform_precessing_spins_vectorized(xp): + """ + Run the tests with vectorization, note that this returns a tuple of arrays. + """ + import lal + from bilby.core.prior import Uniform + from bilby.gw.prior import BBHPriorDict + from bilby.gw.geometry import transform_precessing_spins + from lalsimulation import SimInspiralTransformPrecessingNewInitialConditions + + priors = BBHPriorDict() + priors["mass_1"] = Uniform(1, 1000) + priors["mass_2"] = Uniform(1, 1000) + priors["reference_frequency"] = Uniform(10, 100) + + # some default priors are problematic for some array backends + for key in ["luminosity_distance", "chirp_mass", "mass_ratio"]: + del priors[key] + + points = priors.sample(100, xp=xp) + bilby_transformed = np.asarray(transform_precessing_spins( + points["theta_jn"], + points["phi_jl"], + points["tilt_1"], + points["tilt_2"], + points["phi_12"], + points["a_1"], + points["a_2"], + points["mass_1"], + points["mass_2"], + points["reference_frequency"], + points["phase"], + )) + lalsim_transformed = list() + for ii in range(len(points["theta_jn"])): + point = {key: points[key][ii] for key in points.keys()} + lalsim_transformed.append(np.asarray(SimInspiralTransformPrecessingNewInitialConditions( + float(point["theta_jn"]), + float(point["phi_jl"]), + float(point["tilt_1"]), + float(point["tilt_2"]), + float(point["phi_12"]), + float(point["a_1"]), + float(point["a_2"]), + float(point["mass_1"] * lal.MSUN_SI), + float(point["mass_2"] * lal.MSUN_SI), + float(point["reference_frequency"]), + float(point["phase"]), + ))) + lalsim_transformed = np.asarray(lalsim_transformed).T + np.testing.assert_allclose(bilby_transformed, lalsim_transformed, rtol=1e-10)