-
Notifications
You must be signed in to change notification settings - Fork 123
Support non-numpy array backends #886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5faa432
208f227
c4d9bdf
cd78b34
68abf3c
47041a1
e816198
7a785c2
7ebf340
b558ea6
af8d604
5b5fa6b
c52be69
025e3d5
47baf2c
b501d83
ca7e4f8
a8a9b98
2117df4
3e46a9b
1d891cc
a48544b
15edfba
5ce19b3
8c7e992
cf270d9
f40e845
2f17eee
ee25959
21d2306
3012c99
913428e
2fa1752
8824196
dfb3256
1c03740
0f237d6
f54dfa1
0e7fb3e
8adbdbd
822e08e
d11b2c4
69abdfb
b3c38ba
23479c8
17114a2
a98fb43
f476930
e4c96c3
a85838f
93cda56
0316acc
2c3f8fb
5d10a8a
27f2046
91ee508
5ddf3e3
79ae333
63b6f30
23a3d79
311ced4
cb9703a
ad23f4f
5930568
230f623
f65e668
6488bdb
2213038
a67b4ae
080df9d
2f9cd61
164bc70
2205fc2
aec63af
cc79c54
9d4e01a
9d27356
4222906
4bb8805
f34646a
30b89a8
43cf406
ba6f1ce
0a6f1e2
aafbf1b
8ca9978
bdda315
5bf699e
acd22a3
eea4f24
0a17a5d
713ca68
e0c41db
5951229
38cc5f6
65029de
1759f6a
b74b838
f734a33
602a48d
ebb5ef2
e21790d
721a033
78e94da
30dfd8c
8159bb8
438ed64
00bee6d
f152742
4cfa4fb
b638d8e
dc0ba85
b6df5c3
46f8cf5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| import jax | ||
| import jax.numpy as jnp | ||
| from ..core.likelihood import Likelihood | ||
|
|
||
|
|
||
class JittedLikelihood(Likelihood):
    """
    A wrapper to just-in-time compile a :code:`Bilby` likelihood for use with :code:`jax`.

    Parameters
    ==========
    likelihood: bilby.core.likelihood.Likelihood
        The likelihood to wrap.
    cast_to_float: bool
        Whether to return a float instead of a :code:`jax.Array`.
    """

    def __init__(self, likelihood, cast_to_float=True):
        self._likelihood = likelihood
        # jit-compile once so repeated calls reuse the compiled function
        self._ll = jax.jit(likelihood.log_likelihood)
        self._llr = jax.jit(likelihood.log_likelihood_ratio)
        self.cast_to_float = cast_to_float
        super().__init__()

    def __getattr__(self, name):
        # Delegate unknown attribute access to the wrapped likelihood.
        # Guard against infinite recursion when ``_likelihood`` itself is
        # missing (e.g. on a partially initialized or unpickled instance).
        if name == "_likelihood":
            raise AttributeError(name)
        return getattr(self._likelihood, name)

    def log_likelihood(self, parameters):
        """Evaluate the jitted log likelihood for the given parameter dict."""
        # cast all values to jax arrays so the jitted function sees a
        # consistent pytree structure
        parameters = {k: jnp.array(v) for k, v in parameters.items()}
        ln_l = self._ll(parameters)
        if self.cast_to_float:
            ln_l = float(ln_l)
        return ln_l

    def log_likelihood_ratio(self, parameters):
        """Evaluate the jitted log likelihood ratio for the given parameter dict."""
        parameters = {k: jnp.array(v) for k, v in parameters.items()}
        ln_l = self._llr(parameters)
        if self.cast_to_float:
            ln_l = float(ln_l)
        return ln_l
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import array_api_compat as aac | ||
|
|
||
| from .utils import BackendNotImplementedError | ||
|
|
||
|
|
||
def erfinv_import(xp):
    """
    Import the inverse error function appropriate for the array backend.

    Parameters
    ==========
    xp: module
        The array namespace (e.g., numpy, jax.numpy, torch, cupy).

    Returns
    =======
    callable
        The backend-specific :code:`erfinv` function.

    Raises
    ======
    BackendNotImplementedError
        If no :code:`erfinv` implementation is known for the backend.
    """
    if aac.is_numpy_namespace(xp):
        from scipy.special import erfinv
    elif aac.is_jax_namespace(xp):
        from jax.scipy.special import erfinv
    elif aac.is_torch_namespace(xp):
        from torch.special import erfinv
    elif aac.is_cupy_namespace(xp):
        from cupyx.scipy.special import erfinv
    else:
        # include the offending backend in the error to ease debugging
        raise BackendNotImplementedError(
            f"erfinv is not implemented for backend {xp.__name__}"
        )
    return erfinv
|
|
||
|
|
||
def multivariate_logpdf(xp, mean, cov):
    """
    Build a multivariate-normal log-pdf callable for the array backend.

    Parameters
    ==========
    xp: module
        The array namespace (e.g., numpy, jax.numpy, torch).
    mean: array_like
        Mean vector of the distribution.
    cov: array_like
        Covariance matrix of the distribution.

    Returns
    =======
    callable
        A function evaluating the log-pdf at a given point.

    Raises
    ======
    BackendNotImplementedError
        If no multivariate-normal implementation is known for the backend.
    """
    if aac.is_numpy_namespace(xp):
        from scipy.stats import multivariate_normal

        logpdf = multivariate_normal(mean=mean, cov=cov).logpdf
    elif aac.is_jax_namespace(xp):
        from functools import partial
        from jax.scipy.stats.multivariate_normal import logpdf

        logpdf = partial(logpdf, mean=mean, cov=cov)
    elif aac.is_torch_namespace(xp):
        from torch.distributions.multivariate_normal import MultivariateNormal

        # NOTE(review): only ``cov`` is explicitly converted here; presumably
        # ``mean`` is already a tensor — confirm against callers.
        mvn = MultivariateNormal(loc=mean, covariance_matrix=xp.asarray(cov))
        logpdf = mvn.log_prob
    else:
        # NOTE(review): cupy is supported in erfinv_import but not here —
        # confirm whether a cupy branch should be added for consistency.
        raise BackendNotImplementedError(
            f"multivariate_logpdf is not implemented for backend {xp.__name__}"
        )
    return logpdf
|
|
||
|
|
||
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, *, xp=None):
    """
    Compute log-sum-exp using the implementation matching the array backend.

    Parameters
    ==========
    a: array_like
        Input array.
    axis: int or tuple, optional
        Axis or axes over which the sum is taken.
    b: array_like, optional
        Scaling factors inside the log-sum-exp (as in scipy).
    keepdims: bool
        Whether to retain reduced axes with size one.
    return_sign: bool
        Whether to additionally return the sign of the result.
    xp: module, optional
        The array namespace; inferred from :code:`a` if not given.
    """
    if xp is None:
        xp = aac.get_namespace(a)

    # the scipy version of logsumexp cannot be vmapped, so use jax's own
    if aac.is_jax_namespace(xp):
        from jax.scipy.special import logsumexp as lse
    else:
        from scipy.special import logsumexp as lse

    return lse(a=a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| import numpy as np | ||
|
|
||
| Real = float | int | ||
| ArrayLike = np.ndarray | list | tuple |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| import inspect | ||
| from collections.abc import Iterable | ||
|
|
||
| import numpy as np | ||
| from array_api_compat import array_namespace | ||
|
|
||
| from ..core.utils.log import logger | ||
|
|
||
| __all__ = ["array_module", "promote_to_array"] | ||
|
|
||
|
|
||
def array_module(arr):
    """
    Determine the array namespace (e.g., numpy, jax.numpy) for the input.

    Falls back to numpy for builtin containers, pandas objects, and any
    input whose namespace cannot otherwise be determined.

    Parameters
    ==========
    arr: object
        An array, a builtin container of arrays, or a dict of arrays.

    Returns
    =======
    module
        The array namespace for :code:`arr`.
    """
    # unwrap a single-element tuple (e.g., from *args packing)
    if isinstance(arr, tuple) and len(arr) == 1:
        arr = arr[0]
    try:
        return array_namespace(arr)
    except TypeError:
        if isinstance(arr, dict):
            # dicts of arrays may contain incidental string values; skip them
            try:
                return array_namespace(*[val for val in arr.values() if not isinstance(val, str)])
            except TypeError:
                return np
        elif arr.__class__.__module__ == "builtins" and isinstance(arr, Iterable):
            # a builtin container of arrays: infer from its elements
            try:
                return array_namespace(*arr)
            except TypeError:
                return np
        elif arr.__class__.__module__ == "builtins":
            # builtin scalars and other non-iterables default to numpy
            return np
        elif arr.__class__.__module__.startswith("pandas"):
            # pandas objects are numpy-backed
            return np
        else:
            logger.warning(
                f"Unknown array module for type: {type(arr)} Defaulting to numpy."
            )
            return np
|
|
||
|
|
||
def promote_to_array(args, backend, skip=None):
    """
    Cast positional arguments to arrays of the target backend.

    Parameters
    ==========
    args: tuple
        The positional arguments to promote.
    backend: module
        The array namespace to promote to; numpy input is left untouched.
    skip: int, optional
        If given, the number of trailing arguments to leave unconverted.

    Returns
    =======
    tuple
        The (possibly) promoted arguments.
    """
    # NOTE(review): conversion assumes backend.array performs any required
    # host/device transfer; GPU backends may need explicit device handling.
    if skip is None:
        skip = len(args)
    else:
        # convert only the first len(args) - skip arguments
        skip = len(args) - skip
    if backend.__name__ != "numpy":
        args = tuple(backend.array(arg) for arg in args[:skip]) + args[skip:]
    return args
|
|
||
|
|
||
def xp_wrap(func, no_xp=False):
    """
    A decorator that will figure out the array module from the input
    arguments and pass it to the function as the 'xp' keyword argument.

    Parameters
    ==========
    func: function
        The function to be decorated.
    no_xp: bool
        If True, the decorator will not attempt to add the 'xp' keyword
        argument and so the wrapper is a no-op.

    Returns
    =======
    function
        The decorated function.
    """
    def parse_args_kwargs_for_xp(*args, xp=None, **kwargs):
        if not no_xp and xp is None:
            try:
                # if the user specified the target arrays in kwargs
                # we need to be able to support this, if there is
                # only one kwargs, pass it through alone, this is
                # sometimes a dictionary of arrays so this is needed
                # to remove a level of nesting
                if len(args) > 0:
                    xp = array_module(args)
                elif len(kwargs) == 1:
                    xp = array_module(next(iter(kwargs.values())))
                elif len(kwargs) > 1:
                    xp = array_module(kwargs)
                else:
                    xp = np
                kwargs["xp"] = xp
            except TypeError as e:
                # log (not print) the failure and fall back to numpy
                logger.warning(
                    f"Failed to infer array module ({e}), defaulting to numpy."
                )
                kwargs["xp"] = np
        elif not no_xp:
            kwargs["xp"] = xp
        return args, kwargs

    # bound methods need 'self'/'cls' forwarded explicitly so it is not
    # swept into the xp inference
    sig = inspect.signature(func)
    if any(name in sig.parameters for name in ("self", "cls")):
        def wrapped(self, *args, xp=None, **kwargs):
            args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs)
            return func(self, *args, **kwargs)
    else:
        def wrapped(*args, xp=None, **kwargs):
            args, kwargs = parse_args_kwargs_for_xp(*args, xp=xp, **kwargs)
            return func(*args, **kwargs)

    return wrapped
|
|
||
|
|
||
class BackendNotImplementedError(NotImplementedError):
    """Raised when an operation is not implemented for the requested array backend."""
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the idea that this file provides compatibility between all the different array types? Dare I say it, but it feels like this should be a whole python package in itself...