Skip to content

Commit 4b6a070

Browse files
authored
Merge pull request #231 from scipp/interpolator_optmizations
Time-of-flight interpolator optmizations
2 parents e79ffa8 + f45aa55 commit 4b6a070

File tree

4 files changed

+71
-19
lines changed

4 files changed

+71
-19
lines changed

src/ess/reduce/time_of_flight/eto_to_tof.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,11 @@ def __init__(self, lookup: sc.DataArray, distance_unit: str, time_unit: str):
308308
)
309309

310310
def __call__(
311-
self, ltotal: sc.Variable, event_time_offset: sc.Variable
311+
self,
312+
ltotal: sc.Variable,
313+
event_time_offset: sc.Variable,
314+
pulse_period: sc.Variable,
315+
pulse_index: sc.Variable | None = None,
312316
) -> sc.Variable:
313317
if ltotal.unit != self._distance_unit:
314318
raise sc.UnitError(
@@ -326,7 +330,12 @@ def __call__(
326330

327331
return sc.array(
328332
dims=out_dims,
329-
values=self._interpolator(times=event_time_offset, distances=ltotal),
333+
values=self._interpolator(
334+
times=event_time_offset,
335+
distances=ltotal,
336+
pulse_index=pulse_index.values if pulse_index is not None else None,
337+
pulse_period=pulse_period.value,
338+
),
330339
unit=self._time_unit,
331340
)
332341

@@ -359,7 +368,11 @@ def _time_of_flight_data_histogram(
359368
interp = TofInterpolator(lookup, distance_unit=ltotal.unit, time_unit=eto_unit)
360369

361370
# Compute time-of-flight of the bin edges using the interpolator
362-
tofs = interp(ltotal=ltotal.broadcast(sizes=etos.sizes), event_time_offset=etos)
371+
tofs = interp(
372+
ltotal=ltotal.broadcast(sizes=etos.sizes),
373+
event_time_offset=etos,
374+
pulse_period=pulse_period,
375+
)
363376

364377
return rebinned.assign_coords(tof=tofs)
365378

@@ -418,11 +431,13 @@ def _guess_pulse_stride_offset(
418431
values=event_time_offset.values[inds],
419432
unit=event_time_offset.unit,
420433
)
421-
pulse_period = pulse_period.to(unit=etos.unit)
422434
for i in range(pulse_stride):
423435
pulse_inds = (pulse_index + i) % pulse_stride
424436
tofs[i] = interp(
425-
ltotal=ltotal, event_time_offset=etos + pulse_inds * pulse_period
437+
ltotal=ltotal,
438+
event_time_offset=etos,
439+
pulse_index=pulse_inds,
440+
pulse_period=pulse_period,
426441
)
427442
# Find the entry in the list with the least number of nan values
428443
return sorted(tofs, key=lambda x: sc.isnan(tofs[x]).sum())[0]
@@ -446,12 +461,12 @@ def _time_of_flight_data_events(
446461
ltotal = sc.bins_like(etos, ltotal).bins.constituents["data"]
447462
etos = etos.bins.constituents["data"]
448463

449-
# Compute a pulse index for every event: it is the index of the pulse within a
450-
# frame period. When there is no pulse skipping, those are all zero. When there is
451-
# pulse skipping, the index ranges from zero to pulse_stride - 1.
452-
if pulse_stride == 1:
453-
pulse_index = sc.zeros(sizes=etos.sizes)
454-
else:
464+
pulse_index = None
465+
pulse_period = pulse_period.to(unit=eto_unit)
466+
467+
if pulse_stride > 1:
468+
# Compute a pulse index for every event: it is the index of the pulse within a
469+
# frame period. The index ranges from zero to pulse_stride - 1.
455470
etz_unit = 'ns'
456471
etz = (
457472
da.bins.coords["event_time_zero"]
@@ -495,7 +510,9 @@ def _time_of_flight_data_events(
495510
# Compute time-of-flight for all neutrons using the interpolator
496511
tofs = interp(
497512
ltotal=ltotal,
498-
event_time_offset=etos + pulse_index * pulse_period.to(unit=eto_unit),
513+
event_time_offset=etos,
514+
pulse_index=pulse_index,
515+
pulse_period=pulse_period,
499516
)
500517

501518
parts = da.bins.constituents

src/ess/reduce/time_of_flight/interpolator_numba.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ def interpolate(
1111
values: np.ndarray,
1212
xp: np.ndarray,
1313
yp: np.ndarray,
14+
xoffset: np.ndarray | None,
15+
deltax: float,
1416
fill_value: float,
1517
out: np.ndarray,
1618
):
@@ -29,6 +31,10 @@ def interpolate(
2931
1D array of x-coordinates where to interpolate (size N).
3032
yp:
3133
1D array of y-coordinates where to interpolate (size N).
34+
xoffset:
35+
1D array of integer offsets to apply to the x-coordinates (size N).
36+
deltax:
37+
Multiplier to apply to the integer offsets (i.e. the step size).
3238
fill_value:
3339
Value to use for points outside of the grid.
3440
out:
@@ -52,7 +58,7 @@ def interpolate(
5258
norm = one_over_dx * one_over_dy
5359

5460
for i in prange(npoints):
55-
xx = xp[i]
61+
xx = xp[i] + (xoffset[i] * deltax if xoffset is not None else 0.0)
5662
yy = yp[i]
5763

5864
if (xx < xmin) or (xx > xmax) or (yy < ymin) or (yy > ymax):
@@ -108,14 +114,22 @@ def __init__(
108114
self.values = values
109115
self.fill_value = fill_value
110116

111-
def __call__(self, times: np.ndarray, distances: np.ndarray) -> np.ndarray:
117+
def __call__(
118+
self,
119+
times: np.ndarray,
120+
distances: np.ndarray,
121+
pulse_period: float = 0.0,
122+
pulse_index: np.ndarray | None = None,
123+
) -> np.ndarray:
112124
out = np.empty_like(times)
113125
interpolate(
114126
x=self.time_edges,
115127
y=self.distance_edges,
116128
values=self.values,
117129
xp=times,
118130
yp=distances,
131+
xoffset=pulse_index,
132+
deltax=pulse_period,
119133
fill_value=self.fill_value,
120134
out=out,
121135
)

src/ess/reduce/time_of_flight/interpolator_scipy.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,21 @@ def __init__(
3939
from scipy.interpolate import RegularGridInterpolator
4040

4141
self._interp = RegularGridInterpolator(
42-
(
43-
distance_edges,
44-
time_edges,
45-
),
42+
(distance_edges, time_edges),
4643
values,
4744
method=method,
4845
bounds_error=bounds_error,
4946
fill_value=fill_value,
5047
**kwargs,
5148
)
5249

53-
def __call__(self, times: np.ndarray, distances: np.ndarray) -> np.ndarray:
50+
def __call__(
51+
self,
52+
times: np.ndarray,
53+
distances: np.ndarray,
54+
pulse_period: float = 0.0,
55+
pulse_index: np.ndarray | None = None,
56+
) -> np.ndarray:
57+
if pulse_index is not None:
58+
times = times + (pulse_index * pulse_period)
5459
return self._interp((distances, times))

tests/time_of_flight/interpolator_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,22 @@ def test_numba_and_scipy_interpolators_yield_same_results():
5353
assert np.allclose(numba_result, scipy_result)
5454

5555

56+
def test_numba_and_scipy_interpolators_yield_same_results_with_pulse_offset():
57+
numba_interp, scipy_interp = _make_interpolators()
58+
59+
rng = np.random.default_rng(seed=42)
60+
npoints = 1000
61+
times = rng.uniform(0, 71, npoints)
62+
distances = rng.uniform(40, 70, npoints)
63+
offsets = rng.uniform(0, 2, npoints)
64+
period = 1.0
65+
66+
numba_result = numba_interp(times, distances, period, offsets)
67+
scipy_result = scipy_interp(times, distances, period, offsets)
68+
69+
assert np.allclose(numba_result, scipy_result, equal_nan=True)
70+
71+
5672
def test_numba_and_scipy_interpolators_yield_same_results_with_out_of_bounds():
5773
numba_interp, scipy_interp = _make_interpolators()
5874

0 commit comments

Comments
 (0)