Skip to content

Commit 56852e6

Browse files
committed
fix: use scipp nansum
1 parent c7df6c2 commit 56852e6

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

src/ess/reflectometry/tools.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -261,19 +261,19 @@ def combine_curves(
261261
if len({c.coords['Q'].unit for c in curves}) != 1:
262262
raise ValueError('The Q-coordinates must have the same unit for each curve')
263263

264-
r = _interpolate_on_qgrid(map(sc.values, curves), q_bin_edges).values
265-
v = _interpolate_on_qgrid(map(sc.variances, curves), q_bin_edges).values
264+
r = _interpolate_on_qgrid(map(sc.values, curves), q_bin_edges)
265+
v = _interpolate_on_qgrid(map(sc.variances, curves), q_bin_edges)
266266

267-
v[v == 0] = np.nan
267+
v = sc.where(v == 0, sc.scalar(np.nan, unit=v.unit), v)
268268
inv_v = 1.0 / v
269-
r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0)
270-
v_avg = 1 / np.nansum(inv_v, axis=0)
269+
r_avg = sc.nansum(r * inv_v, dim='curves') / sc.nansum(inv_v, dim='curves')
270+
v_avg = 1 / sc.nansum(inv_v, dim='curves')
271271

272272
out = sc.DataArray(
273273
data=sc.array(
274274
dims='Q',
275-
values=r_avg,
276-
variances=v_avg,
275+
values=r_avg.values,
276+
variances=v_avg.values,
277277
unit=next(iter(curves)).data.unit,
278278
),
279279
coords={'Q': q_bin_edges},
@@ -286,18 +286,15 @@ def combine_curves(
286286
q_res = (
287287
sc.DataArray(
288288
data=c.coords.get(
289-
'Q_resolution', sc.scalar(float('nan')) * sc.values(c.data.copy())
289+
'Q_resolution', sc.full_like(c.coords['Q'], value=np.nan)
290290
),
291291
coords={'Q': c.coords['Q']},
292292
)
293293
for c in curves
294294
)
295-
qs = _interpolate_on_qgrid(q_res, q_bin_edges).values
296-
qs_avg = np.nansum(qs * inv_v, axis=0) / np.nansum(
297-
~np.isnan(qs) * inv_v, axis=0
298-
)
299-
out.coords['Q_resolution'] = sc.array(
300-
dims='Q', values=qs_avg, unit=next(iter(curves)).coords['Q_resolution'].unit
295+
qs = _interpolate_on_qgrid(q_res, q_bin_edges)
296+
out.coords['Q_resolution'] = sc.nansum(qs * inv_v, dim='curves') / sc.nansum(
297+
sc.where(sc.isnan(qs), sc.scalar(0.0, unit=inv_v.unit), inv_v), dim='curves'
301298
)
302299
return out
303300

0 commit comments

Comments
 (0)