Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
114 commits
Select commit Hold shift + click to select a range
5faa432
FEAT: enable backend switching for base gravitational-wave transient …
ColmTalbot Oct 25, 2024
208f227
FEAT: support multiband and relative binning likelihoods
ColmTalbot Oct 25, 2024
c4d9bdf
FEAT: make more conversions backend agnostic
ColmTalbot Oct 26, 2024
cd78b34
FEAT: use more normal conversions
ColmTalbot Oct 28, 2024
68abf3c
FEAT: move backend switching code to bilby
ColmTalbot Nov 13, 2024
47041a1
FEAT: make core prior backend agnostic
ColmTalbot Nov 14, 2024
e816198
FEAT: make non-numpy arrays serializable
ColmTalbot Nov 14, 2024
7a785c2
BUG: fix some array conversion methods
ColmTalbot Nov 14, 2024
7ebf340
DEV: some more prior agnosticism
ColmTalbot Dec 11, 2024
b558ea6
TEST: make all prior tests run
ColmTalbot Dec 12, 2024
af8d604
DEV: move some jax functionality to compat
ColmTalbot Jan 25, 2025
5b5fa6b
REFACTOR: use array backend for ln_i0
ColmTalbot Jan 25, 2025
c52be69
make distance marginalizatio backend transparent
ColmTalbot Jan 25, 2025
025e3d5
DEV: some more prior dict array refactoring
ColmTalbot Jan 25, 2025
47baf2c
fix jax logic for distance marginalization
ColmTalbot Jan 29, 2025
b501d83
improve efficiency of setting up multibanding
ColmTalbot Jan 29, 2025
ca7e4f8
make high-dimensional gaussians jax compatible
ColmTalbot Jan 29, 2025
a8a9b98
make cubic spline calibration work with jax backend
ColmTalbot Jan 30, 2025
2117df4
BUG: fix linspace calls
ColmTalbot Feb 4, 2025
3e46a9b
ENH: fix bottleneck in relative binning for JAX
ColmTalbot Feb 4, 2025
1d891cc
ENH: make interpolated prior backend friendly
ColmTalbot Feb 4, 2025
a48544b
REFACTOR: refactor backend-specific interpolation code
ColmTalbot Feb 5, 2025
15edfba
ENH: make sine gaussian model backend independent
ColmTalbot Feb 5, 2025
5ce19b3
ENH: make roq likelihood backend independent
ColmTalbot Feb 5, 2025
8c7e992
BUG: fix roq slicing
ColmTalbot Feb 5, 2025
cf270d9
FEAT: make condition chi evaluable
ColmTalbot Jun 3, 2025
f40e845
MAINT: make whitening work for non-numpy
ColmTalbot Jun 12, 2025
2f17eee
EXAMPLE: update jax example
ColmTalbot Aug 20, 2025
ee25959
BUG: fix interpax interpolation method
ColmTalbot Aug 20, 2025
21d2306
REFACTOR: update variable backend for new parameter method
ColmTalbot Oct 2, 2025
3012c99
some simplifications of array transparency
ColmTalbot Oct 2, 2025
913428e
HYPER: make hyperparameter likelihood handle array backends
ColmTalbot Oct 2, 2025
2fa1752
MAINT: switch back to bilby_cython
ColmTalbot Dec 11, 2025
8824196
TYPO: fix typo in multiband time-marginalized likelihood
ColmTalbot Dec 22, 2025
dfb3256
MAINT: removed unused import
ColmTalbot Dec 22, 2025
1c03740
BUG: add explicit array cast in conversion
ColmTalbot Dec 22, 2025
0f237d6
REFACTOR: some refactoring of array edge cases
ColmTalbot Dec 22, 2025
f54dfa1
MAINT: removed extra ripple code
ColmTalbot Dec 22, 2025
0e7fb3e
REFACTOR: make bilby_cython an optional dependency
ColmTalbot Dec 22, 2025
8adbdbd
FMT: formatting fixes
ColmTalbot Dec 22, 2025
822e08e
BUG: fix array introspection for conversion
ColmTalbot Jan 21, 2026
d11b2c4
REFACTOR: make parameters for waveform generator more strict
ColmTalbot Jan 21, 2026
69abdfb
BUG: fix core likelihood tests
ColmTalbot Jan 21, 2026
b3c38ba
BUG: fix calibration calculations
ColmTalbot Jan 22, 2026
23479c8
EXAMPLE: update jax fast tutorial
ColmTalbot Jan 22, 2026
17114a2
TST: refactor marginalization tests to be less restrictive
ColmTalbot Jan 22, 2026
a98fb43
DOC: update jittedlikelihood docstring
ColmTalbot Jan 22, 2026
f476930
TEST: speed up initializing prior tests
ColmTalbot Jan 22, 2026
e4c96c3
BUG: fix some test failures
ColmTalbot Jan 22, 2026
a85838f
BUG: fix conditional+joint prior rescaling
ColmTalbot Jan 22, 2026
93cda56
BUG: fix some gnarly conversion corner cases
ColmTalbot Jan 22, 2026
0316acc
BUG: fix multiband likelihood
ColmTalbot Jan 22, 2026
2c3f8fb
BUG: fix bug in array_namespace check
ColmTalbot Jan 22, 2026
5d10a8a
TEST: make sure healpix prior doesn't store state between calls
ColmTalbot Jan 22, 2026
27f2046
FMT: example formatting fixes
ColmTalbot Jan 22, 2026
91ee508
BUG: make sure indices don't overflow in roq
ColmTalbot Jan 22, 2026
5ddf3e3
BUG: fix multiband time marginalization setup
ColmTalbot Jan 23, 2026
79ae333
BUG: fix roq interpolation for out of bounds sample
ColmTalbot Jan 23, 2026
63b6f30
TYPO: fix typo in jax example
ColmTalbot Jan 23, 2026
23a3d79
REFACTOR: refactor more roq likelihood tests
ColmTalbot Jan 23, 2026
311ced4
MAINT: revert new conversions
ColmTalbot Jan 23, 2026
cb9703a
CI: fix selecting only non-windows os
ColmTalbot Jan 23, 2026
ad23f4f
MAINT: make sure compat subpackages are listed in pyproject
ColmTalbot Jan 23, 2026
5930568
TYPO: Fix package list formatting in pyproject.toml
ColmTalbot Jan 23, 2026
230f623
BUG: readd erroneously removed line
ColmTalbot Jan 23, 2026
f65e668
DOC: remove extraneous docstring
ColmTalbot Jan 23, 2026
6488bdb
Merge branch 'main' into bilback
ColmTalbot Jan 28, 2026
2213038
TEST: fix test failures
ColmTalbot Jan 29, 2026
a67b4ae
TEST: start adding jax tests
ColmTalbot Jan 31, 2026
080df9d
CI: add jax tests to CI
ColmTalbot Jan 31, 2026
2f9cd61
Merge branch 'main' into bilback
ColmTalbot Jan 31, 2026
164bc70
MAINT: add jax extras option
ColmTalbot Jan 31, 2026
2205fc2
Some more jax testing updates
ColmTalbot Jan 31, 2026
aec63af
MAINT: actually add jax requirements
ColmTalbot Jan 31, 2026
cc79c54
CI: don't trivially skip all tests...
ColmTalbot Jan 31, 2026
9d4e01a
Initial pass at making grid work with jax
ColmTalbot Jan 31, 2026
9d27356
TEST: add more jax test coverage
ColmTalbot Feb 1, 2026
4222906
FMT: precommit fixes
ColmTalbot Feb 2, 2026
4bb8805
TEST: fix jax tests
ColmTalbot Feb 2, 2026
f34646a
TEST: add basic gw conversion jax tests
ColmTalbot Feb 2, 2026
30b89a8
TEST: more debugging slab spike test
ColmTalbot Feb 2, 2026
43cf406
TEST: jax tests work again
ColmTalbot Feb 2, 2026
ba6f1ce
DOC: add initial doc page for array backend
ColmTalbot Feb 2, 2026
0a6f1e2
TEST: add a bunch of gw tests
ColmTalbot Feb 2, 2026
aafbf1b
DOC: fix doc page formatting
ColmTalbot Feb 2, 2026
8ca9978
FMT: fix formatting
ColmTalbot Feb 2, 2026
bdda315
BUG: fix typo in bilby_cython call
ColmTalbot Feb 2, 2026
5bf699e
BUG: fix list input for asd calculation
ColmTalbot Feb 2, 2026
acd22a3
FMT: fix syntax for array conversion and backend checks
ColmTalbot Feb 2, 2026
eea4f24
BUG: fix some broken formatting
ColmTalbot Feb 2, 2026
0a17a5d
FMT: fix formatting
ColmTalbot Feb 2, 2026
713ca68
BUG: fix bugs in testing
ColmTalbot Feb 2, 2026
e0c41db
Fix some more conversions
ColmTalbot Feb 2, 2026
5951229
Add pytorch core testing
ColmTalbot Feb 3, 2026
38cc5f6
FMT: run precommits
ColmTalbot Feb 3, 2026
65029de
Make torch fully tested
ColmTalbot Feb 3, 2026
1759f6a
FMT: pre-commit fix
ColmTalbot Feb 3, 2026
b74b838
TEST: fix torch roq tests
ColmTalbot Feb 3, 2026
f734a33
CI: prioritize torch tests
ColmTalbot Feb 3, 2026
602a48d
TEST: another attempt to fix torch tests
ColmTalbot Feb 3, 2026
ebb5ef2
Another attempt at fixing torch ROQ tests
ColmTalbot Feb 3, 2026
e21790d
Fix arrays of data setting
ColmTalbot Feb 3, 2026
721a033
BUG: fix some more roq array issues
ColmTalbot Feb 3, 2026
78e94da
Make ROQ calculations use correct array backend
ColmTalbot Feb 3, 2026
30dfd8c
BUG: fix a missing array case
ColmTalbot Feb 3, 2026
8159bb8
FMT: pre-commit fixes
ColmTalbot Feb 3, 2026
438ed64
CI: drop torch tests for python 3.10
ColmTalbot Feb 3, 2026
00bee6d
FMT: precommit fix
ColmTalbot Feb 3, 2026
f152742
TEST: exclude studentt tests for jax
ColmTalbot Feb 3, 2026
4cfa4fb
Add some more explicit array casts
ColmTalbot Feb 3, 2026
b638d8e
BUG: bug fixes for prior and gw likelihoods
ColmTalbot Feb 17, 2026
dc0ba85
BUG: fix array namespace for torch
ColmTalbot Feb 17, 2026
b6df5c3
Merge branch 'main' into bilback
ColmTalbot Feb 18, 2026
46f8cf5
Merge branch 'main' into bilback
ColmTalbot Feb 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions .github/workflows/basic-install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# disable windows build test as bilby_cython is currently broken there
os: [ubuntu-latest, macos-latest]
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -52,8 +51,8 @@ jobs:
python -c "import bilby.hyper"
python -c "import cli_bilby"
python test/import_test.py
# - if: ${{ matrix.os != "windows-latest" }}
# run: |
# for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do
# ${script} --help;
# done
- if: runner.os != 'Windows'
run: |
for script in $(pip show -f bilby | grep "bin\/" | xargs -I {} basename {}); do
${script} --help;
done
9 changes: 9 additions & 0 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ jobs:
- name: Run unit tests
run: |
python -m pytest --cov=bilby --cov-branch --durations 10 -ra --color yes --cov-report=xml --junitxml=pytest.xml
- name: Run jax-backend unit tests
run: |
python -m pip install .[jax]
SCIPY_ARRAY_API=1 pytest --array-backend jax --durations 10
- name: Run torch-backend unit tests
# there are scipy version issues with python 3.10 and torch
if: matrix.python.version > 3.10
run: |
SCIPY_ARRAY_API=1 pytest --array-backend torch --durations 10
- name: Run sampler tests
run: |
pytest test/integration/sampler_run_test.py --durations 10 -v
Expand Down
Empty file added bilby/compat/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions bilby/compat/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import jax
import jax.numpy as jnp
from ..core.likelihood import Likelihood


class JittedLikelihood(Likelihood):
"""
A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`.

Parameters
==========
likelihood: bilby.core.likelihood.Likelihood
The likelihood to wrap.
cast_to_float: bool
Whether to return a float instead of a :code:`jax.Array`.
"""

def __init__(self, likelihood, cast_to_float=True):
self._likelihood = likelihood
self._ll = jax.jit(likelihood.log_likelihood)
self._llr = jax.jit(likelihood.log_likelihood_ratio)
self.cast_to_float = cast_to_float
super().__init__()

def __getattr__(self, name):
return getattr(self._likelihood, name)

def log_likelihood(self, parameters):
parameters = {k: jnp.array(v) for k, v in parameters.items()}
ln_l = self._ll(parameters)
if self.cast_to_float:
ln_l = float(ln_l)
return ln_l

def log_likelihood_ratio(self, parameters):
parameters = {k: jnp.array(v) for k, v in parameters.items()}
ln_l = self._llr(parameters)
if self.cast_to_float:
ln_l = float(ln_l)
return ln_l
50 changes: 50 additions & 0 deletions bilby/compat/patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import array_api_compat as aac
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the idea that this file provides compatibility between all the different array types? Dare I say it, but it feels like this should be a whole python package in itself...


from .utils import BackendNotImplementedError


def erfinv_import(xp):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of these functions would benefit from a docstring to explain they do the import given the type of array backend.

if aac.is_numpy_namespace(xp):
from scipy.special import erfinv
elif aac.is_jax_namespace(xp):
from jax.scipy.special import erfinv
elif aac.is_torch_namespace(xp):
from torch.special import erfinv
elif aac.is_cupy_namespace(xp):
from cupyx.scipy.special import erfinv
else:
raise BackendNotImplementedError
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be useful to include the backend in the error.

return erfinv


def multivariate_logpdf(xp, mean, cov):
if aac.is_numpy_namespace(xp):
from scipy.stats import multivariate_normal

logpdf = multivariate_normal(mean=mean, cov=cov).logpdf
elif aac.is_jax_namespace(xp):
from functools import partial
from jax.scipy.stats.multivariate_normal import logpdf

logpdf = partial(logpdf, mean=mean, cov=cov)
elif aac.is_torch_namespace(xp):
from torch.distributions.multivariate_normal import MultivariateNormal

mvn = MultivariateNormal(loc=mean, covariance_matrix=xp.asarray(cov))
logpdf = mvn.log_prob
else:
raise BackendNotImplementedError
return logpdf


def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, *, xp=None):
if xp is None:
xp = aac.get_namespace(a)

# the scipy version of logsumexp cannot be vmapped
if aac.is_jax_namespace(xp):
from jax.scipy.special import logsumexp as lse
else:
from scipy.special import logsumexp as lse

return lse(a=a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign)
4 changes: 4 additions & 0 deletions bilby/compat/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import numpy as np

Real = float | int
ArrayLike = np.ndarray | list | tuple
105 changes: 105 additions & 0 deletions bilby/compat/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import inspect
from collections.abc import Iterable

import numpy as np
from array_api_compat import array_namespace

from ..core.utils.log import logger

__all__ = ["array_module", "promote_to_array"]


def array_module(arr):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would benefit from a doc-string

if isinstance(arr, tuple) and len(arr) == 1:
arr = arr[0]
try:
return array_namespace(arr)
except TypeError:
if isinstance(arr, dict):
try:
return array_namespace(*[val for val in arr.values() if not isinstance(val, str)])
except TypeError:
return np
elif arr.__class__.__module__ == "builtins" and isinstance(arr, Iterable):
try:
return array_namespace(*arr)
except TypeError:
return np
elif arr.__class__.__module__ == "builtins":
return np
elif arr.__module__.startswith("pandas"):
return np
else:
logger.warning(
f"Unknown array module for type: {type(arr)} Defaulting to numpy."
)
return np


def promote_to_array(args, backend, skip=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you thought about how devices would be handled here?

Moving arrays to a from GPUs can sometimes require more than just calling array.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest adding a doc-string

if skip is None:
skip = len(args)
else:
skip = len(args) - skip
if backend.__name__ != "numpy":
args = tuple(backend.array(arg) for arg in args[:skip]) + args[skip:]
return args


def xp_wrap(func, no_xp=False):
"""
A decorator that will figure out the array module from the input
arguments and pass it to the function as the 'xp' keyword argument.

Parameters
==========
func: function
The function to be decorated.
no_xp: bool
If True, the decorator will not attempt to add the 'xp' keyword
argument and so the wrapper is a no-op.

Returns
=======
function
The decorated function.
"""
def parse_args_kwargs_for_xp(*args, xp=None, **kwargs):
if not no_xp and xp is None:
try:
# if the user specified the target arrays in kwargs
# we need to be able to support this, if there is
# only one kwargs, pass it through alone, this is
# sometimes a dictionary of arrays so this is needed
# to remove a level of nesting
if len(args) > 0:
xp = array_module(args)
elif len(kwargs) == 1:
xp = array_module(next(iter(kwargs.values())))
elif len(kwargs) > 1:
xp = array_module(kwargs)
else:
xp = np
kwargs["xp"] = xp
except TypeError as e:
print("type failed", e)
kwargs["xp"] = np
elif not no_xp:
kwargs["xp"] = xp
return args, kwargs

sig = inspect.signature(func)
if any(name in sig.parameters for name in ("self", "cls")):
def wrapped(self, *args, xp=None, **kwargs):
args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs)
return func(self, *args, **kwargs)
else:
def wrapped(*args, xp=None, **kwargs):
args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs)
return func(*args, **kwargs)

return wrapped


class BackendNotImplementedError(NotImplementedError):
pass
Loading
Loading