@@ -308,7 +308,11 @@ def __init__(self, lookup: sc.DataArray, distance_unit: str, time_unit: str):
308
308
)
309
309
310
310
def __call__ (
311
- self , ltotal : sc .Variable , event_time_offset : sc .Variable
311
+ self ,
312
+ ltotal : sc .Variable ,
313
+ event_time_offset : sc .Variable ,
314
+ pulse_period : sc .Variable ,
315
+ pulse_index : sc .Variable | None = None ,
312
316
) -> sc .Variable :
313
317
if ltotal .unit != self ._distance_unit :
314
318
raise sc .UnitError (
@@ -326,7 +330,12 @@ def __call__(
326
330
327
331
return sc .array (
328
332
dims = out_dims ,
329
- values = self ._interpolator (times = event_time_offset , distances = ltotal ),
333
+ values = self ._interpolator (
334
+ times = event_time_offset ,
335
+ distances = ltotal ,
336
+ pulse_index = pulse_index .values if pulse_index is not None else None ,
337
+ pulse_period = pulse_period .value ,
338
+ ),
330
339
unit = self ._time_unit ,
331
340
)
332
341
@@ -359,7 +368,11 @@ def _time_of_flight_data_histogram(
359
368
interp = TofInterpolator (lookup , distance_unit = ltotal .unit , time_unit = eto_unit )
360
369
361
370
# Compute time-of-flight of the bin edges using the interpolator
362
- tofs = interp (ltotal = ltotal .broadcast (sizes = etos .sizes ), event_time_offset = etos )
371
+ tofs = interp (
372
+ ltotal = ltotal .broadcast (sizes = etos .sizes ),
373
+ event_time_offset = etos ,
374
+ pulse_period = pulse_period ,
375
+ )
363
376
364
377
return rebinned .assign_coords (tof = tofs )
365
378
@@ -418,11 +431,13 @@ def _guess_pulse_stride_offset(
418
431
values = event_time_offset .values [inds ],
419
432
unit = event_time_offset .unit ,
420
433
)
421
- pulse_period = pulse_period .to (unit = etos .unit )
422
434
for i in range (pulse_stride ):
423
435
pulse_inds = (pulse_index + i ) % pulse_stride
424
436
tofs [i ] = interp (
425
- ltotal = ltotal , event_time_offset = etos + pulse_inds * pulse_period
437
+ ltotal = ltotal ,
438
+ event_time_offset = etos ,
439
+ pulse_index = pulse_inds ,
440
+ pulse_period = pulse_period ,
426
441
)
427
442
# Find the entry in the list with the least number of nan values
428
443
return sorted (tofs , key = lambda x : sc .isnan (tofs [x ]).sum ())[0 ]
@@ -446,12 +461,12 @@ def _time_of_flight_data_events(
446
461
ltotal = sc .bins_like (etos , ltotal ).bins .constituents ["data" ]
447
462
etos = etos .bins .constituents ["data" ]
448
463
449
- # Compute a pulse index for every event: it is the index of the pulse within a
450
- # frame period. When there is no pulse skipping, those are all zero. When there is
451
- # pulse skipping, the index ranges from zero to pulse_stride - 1.
452
- if pulse_stride == 1 :
453
- pulse_index = sc . zeros ( sizes = etos . sizes )
454
- else :
464
+ pulse_index = None
465
+ pulse_period = pulse_period . to ( unit = eto_unit )
466
+
467
+ if pulse_stride > 1 :
468
+ # Compute a pulse index for every event: it is the index of the pulse within a
469
+ # frame period. The index ranges from zero to pulse_stride - 1.
455
470
etz_unit = 'ns'
456
471
etz = (
457
472
da .bins .coords ["event_time_zero" ]
@@ -495,7 +510,9 @@ def _time_of_flight_data_events(
495
510
# Compute time-of-flight for all neutrons using the interpolator
496
511
tofs = interp (
497
512
ltotal = ltotal ,
498
- event_time_offset = etos + pulse_index * pulse_period .to (unit = eto_unit ),
513
+ event_time_offset = etos ,
514
+ pulse_index = pulse_index ,
515
+ pulse_period = pulse_period ,
499
516
)
500
517
501
518
parts = da .bins .constituents
0 commit comments