Support non-numpy array backends#886
Conversation
ea348fa to
771a8a9
Compare
|
This is now ready for review. There are a lot of changes, but most of them are essentially Bilby can once again be installed without I've managed to keep test changes minimal:
|
This required making some changes to the tests for conditional dicts as I've changed the output types and the backend introspection doesn't work on dict_items for some reason
GregoryAshton
left a comment
There was a problem hiding this comment.
Okay, I got through about 60% of the diff and I'm pausing here so will submit the questions so far.
| @@ -0,0 +1,50 @@ | |||
| import array_api_compat as aac | |||
There was a problem hiding this comment.
Is the idea that this file provides compatibility between all the different array types? Dare I say it, but it feels like this should be a whole python package in itself...
| from .utils import BackendNotImplementedError | ||
|
|
||
|
|
||
| def erfinv_import(xp): |
There was a problem hiding this comment.
All of these functions would benefit from a docstring to explain they do the import given the type of array backend.
| """ | ||
| at_peak = (val == self.peak) | ||
| return np.nan_to_num(np.multiply(at_peak, np.inf)) | ||
| return at_peak * 1.0 |
There was a problem hiding this comment.
May be wise to add a comment here and in the other instance in case someone "cleans it up" down the road.
| _prob[idx] = np.exp(-(np.log(val[idx]) - self.mu) ** 2 / self.sigma ** 2 / 2)\ | ||
| / np.sqrt(2 * np.pi) / val[idx] / self.sigma | ||
| return _prob | ||
| return xp.exp(self.ln_prob(val, xp=xp)) |
There was a problem hiding this comment.
I presume there was some reason we handled things in a complicated way before.. was it just bad/lazy coding or where we fixing some edge case? Anyone recall..
| _cdf[val >= self.minimum] = 1. - np.exp(-val[val >= self.minimum] / self.mu) | ||
| return _cdf | ||
| with np.errstate(divide="ignore"): | ||
| return -val / self.mu - xp.log(xp.asarray(self.mu)) + xp.log(val >= self.minimum) |
There was a problem hiding this comment.
Ah okay - are the bounds being implemented here? But, I don't see the upper bound being implemented.
| log_l = np.sum(- (self.residual(parameters) / sigma)**2 / 2 - | ||
| np.log(2 * np.pi * sigma**2) / 2) | ||
| log_l = xp.sum(- (self.residual(parameters) / sigma)**2 / 2 - | ||
| xp.log(xp.asarray(2 * np.pi * sigma**2)) / 2) |
There was a problem hiding this comment.
I'm seeing this in a lot of places. It isn't required by numpy. Presumably it is required by jax or something else?
| xp = array_module(waveform_polarizations) | ||
| if frequencies is None: | ||
| frequencies = self.frequency_array[self.frequency_mask] | ||
| # frequencies = self.frequency_array[self.frequency_mask] |
There was a problem hiding this comment.
Is there a reason to leave this commented out?
|
|
||
| signal[mode] = waveform_polarizations[mode] * det_response | ||
| signal_ifo = sum(signal.values()) * mask | ||
| signal[mode] = waveform_polarizations[mode] * mask * det_response |
There was a problem hiding this comment.
It looks like this is changing the way the mask is being used. From operating on a view to operating on the full array but zeroing the False cases. Is that correct?
I've been working on this PR on and off for a few months, it isn't ready yet, but I wanted to share it in case other people had early opinions.
The goal is to make it easier to interface with models/samplers implemented in e.g., JAX, that support GPU/TPU acceleration and JIT compilation.
The general guiding principles are:
array-apispecification andscipyinteroperabilityThe primary changes so far are:
Changed behaviour:
Remaining issues:
bilby.gw.jaxstufffile should be removed and relevant functionality be moved elsewhere, it's currently just used for testing