Skip to content

Commit 57206da

Browse files
committed
[python][UHI] Implement TH1::Slice and TH1::SetSliceContent and adapt the uhi backend
1 parent aac51aa commit 57206da

File tree

3 files changed

+320
-128
lines changed

3 files changed

+320
-128
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi.py

Lines changed: 92 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@ def __call__(self, hist):
7777
return rebin_method(*self.ngroup, newname=hist.GetName())
7878

7979

80-
def _sum(hist, axis):
80+
def _sum(hist, axis, args=None):
8181
"""
8282
Represents a summation operation for histograms, which either computes the integral
83-
(1D histograms) or projects the histogram along specified axes (2D and 3D histograms).
83+
or projects the histogram along specified axes (projection is only for 2D and 3D histograms).
8484
8585
Example:
86-
ans = h[::ROOT.uhi.sum] # Compute the integral for a 1D histogram
86+
ans = h[0:len:ROOT.uhi.sum] # Compute the integral for a 1D histogram excluding flow bins
87+
ans_2 = h[::ROOT.uhi.sum, ::ROOT.uhi.sum] # Compute the integral for a 2D histogram including flow bins
8788
h_projected = h[:, ::ROOT.uhi.sum] # Project the Y axis for a 2D histogram
8889
h_projected = h[:, :, ::ROOT.uhi.sum] # Project the Z axis for a 3D histogram
8990
"""
@@ -95,19 +96,28 @@ def _invalid_axis(axis, dim):
9596
if isinstance(axis, int):
9697
axis = (axis,)
9798
if dim == 1:
98-
return hist.Integral()
99+
return hist.Integral(*args) if axis == (0,) else _invalid_axis(axis, dim)
99100
if dim == 2:
100-
return hist.ProjectionX() if axis == (0,) else hist.ProjectionY() if axis == (1,) else _invalid_axis(axis, dim)
101+
if axis == (0,):
102+
return hist.ProjectionY()
103+
elif axis == (1,):
104+
return hist.ProjectionX()
105+
elif axis == (0, 1):
106+
return hist.Integral()
107+
else:
108+
return _invalid_axis(axis, dim)
101109
if dim == 3:
102-
# It is not possible from the interface to specify the options "yx", "zy", "zx"
110+
# It is not possible from the interface to specify the options "xy", "yz", "xz"
103111
project_map = {
104-
(0,): "yz",
105-
(1,): "xz",
106-
(2,): "xy",
112+
(0,): "zy",
113+
(1,): "zx",
114+
(2,): "yx",
107115
(0, 1): "z",
108116
(0, 2): "y",
109117
(1, 2): "x",
110118
}
119+
if axis == (0, 1, 2):
120+
return hist.Integral()
111121
return hist.Project3D(project_map[axis]) if axis in project_map else _invalid_axis(axis, dim)
112122
raise NotImplementedError(f"Summing not implemented for {dim}D histograms")
113123

@@ -148,22 +158,26 @@ def _get_axis_len(self, axis, include_flow_bins=False):
148158
return _get_axis(self, axis).GetNbins() + (2 if include_flow_bins else 0)
149159

150160

151-
def _process_index_for_axis(self, index, axis):
161+
def _process_index_for_axis(self, index, axis, include_flow_bins=False, is_slice_stop=False):
152162
"""Process an index for a histogram axis handling callables and index shifting."""
153163
if callable(index):
154164
# If the index is a `loc`, `underflow`, `overflow`, or `len`
155-
return _get_axis_len(self, axis) if index is len else index(self, axis)
165+
return _get_axis_len(self, axis) + 1 if index is len else index(self, axis)
156166

157167
if isinstance(index, int):
158168
# -1 index returns the last valid bin
159169
if index == -1:
160170
return _overflow(self, axis) - 1
171+
172+
if index == _overflow(self, axis):
173+
return index + (1 if include_flow_bins else 0)
174+
161175
# Shift the indices by 1 to align with the UHI convention,
162176
# where 0 corresponds to the first bin, unlike ROOT where 0 represents underflow and 1 is the first bin.
177+
nbins = _get_axis_len(self, axis) + (1 if is_slice_stop else 0)
163178
index = index + 1
164-
nbins = _get_axis_len(self, axis)
165179
if abs(index) > nbins:
166-
raise IndexError(f"Histogram index {index} out of range for axis {axis}")
180+
raise IndexError(f"Histogram index {index-1} out of range for axis {axis}. Valid range: (0,{nbins})")
167181
return index
168182

169183
raise index
@@ -200,14 +214,12 @@ def _compute_common_index(self, index, include_flow_bins=True):
200214
raise IndexError("Only one ellipsis is allowed in the index.")
201215

202216
if any(idx is ... for idx in index):
203-
expanded_index = []
204-
for idx in index:
205-
if idx is ...:
206-
break
207-
expanded_index.append(idx)
208-
# fill remaining dimensions with `slice(None)`
209-
expanded_index.extend([slice(None)] * (dim - len(expanded_index)))
210-
index = tuple(expanded_index)
217+
ellipsis_pos = index.index(...)
218+
index = (
219+
index[:ellipsis_pos] +
220+
(slice(None),) * (dim - len(index) + 1) +
221+
index[ellipsis_pos + 1:]
222+
)
211223

212224
if len(index) != dim:
213225
raise IndexError(f"Expected {dim} indices, got {len(index)}")
@@ -224,34 +236,40 @@ def _resolve_slice_indices(self, index, axis, include_flow_bins=True):
224236
"""Resolve slice start and stop indices for a given axis"""
225237
start, stop = index.start, index.stop
226238
start = (
227-
_process_index_for_axis(self, start, axis)
239+
_process_index_for_axis(self, start, axis, include_flow_bins)
228240
if start is not None
229241
else _underflow(self, axis) + (0 if include_flow_bins else 1)
230242
)
231243
stop = (
232-
_process_index_for_axis(self, stop, axis)
244+
_process_index_for_axis(self, stop, axis, include_flow_bins, is_slice_stop=True)
233245
if stop is not None
234246
else _overflow(self, axis) + (1 if include_flow_bins else 0)
235247
)
236248
if start < _underflow(self, axis) or stop > (_overflow(self, axis) + 1) or start > stop:
237-
raise IndexError(f"Slice indices {start, stop} out of range for axis {axis}")
249+
raise IndexError(f"Slice indices {start, stop} out of range for axis {axis}. Valid range: {_underflow(self, axis), _overflow(self, axis) + 1}")
238250
return start, stop
239251

240252

241-
def _apply_actions(hist, actions):
253+
def _apply_actions(hist, actions, index, unprocessed_index, original_hist):
242254
"""Apply rebinning or summing actions to the histogram, returns a new histogram"""
243255
if not actions or all(a is None for a in actions):
244256
return hist
245-
246-
if any(a is _sum for a in actions):
247-
sum_axes = tuple(i for i, a in enumerate(actions) if a is _sum)
248-
hist = _sum(hist, sum_axes)
257+
258+
if any(a is _sum or a is sum for a in actions):
259+
sum_axes = tuple(i for i, a in enumerate(actions) if a is _sum or a is sum)
260+
if original_hist.GetDimension() == 1:
261+
start, stop = index[0].start, index[0].stop
262+
include_oflow = True if unprocessed_index.stop is None else False
263+
args = [start, stop - (1 if not include_oflow else 0)]
264+
hist = _sum(original_hist, sum_axes, args)
265+
else:
266+
hist = _sum(hist, sum_axes)
249267

250268
if any(isinstance(a, _rebin) for a in actions):
251269
rebins = [a.ngroup if isinstance(a, _rebin) else 1 for a in actions if a is not _sum]
252270
hist = _rebin(rebins)(hist)
253271

254-
if any(a is not None and not (isinstance(a, _rebin) or a is _sum) for a in actions):
272+
if any(a is not None and not (isinstance(a, _rebin) or a is _sum or a is sum) for a in actions):
255273
raise ValueError(f"Unsupported action detected in actions {actions}")
256274

257275
return hist
@@ -261,102 +279,41 @@ def _get_processed_slices(self, index):
261279
"""Process slices and extract actions for each axis"""
262280
if len(index) != self.GetDimension():
263281
raise IndexError(f"Expected {self.GetDimension()} indices, got {len(index)}")
264-
processed_slices, out_of_range_indices, actions = [], [], [None] * self.GetDimension()
282+
processed_slices, actions = [], [None] * self.GetDimension()
265283
for axis, idx in enumerate(index):
266-
axis_bins = range(_overflow(self, axis) + 1)
267284
if isinstance(idx, slice):
268285
slice_range = range(idx.start, idx.stop)
269286
processed_slices.append(slice_range)
270-
uflow = [b for b in axis_bins if b < idx.start]
271-
oflow = [b for b in axis_bins if b >= idx.stop]
272-
out_of_range_indices.append((uflow, oflow))
273287
actions[axis] = idx.step
288+
elif isinstance(idx, int):
289+
# A single value v is like v:v+1:sum, example: h2 = h[v, a:b]
290+
processed_slices.append(range(idx, idx + 1))
291+
actions[axis] = _sum
274292
else:
275293
processed_slices.append([idx])
276294

277-
return processed_slices, out_of_range_indices, actions
278-
279-
280-
def _get_slice_indices(slices):
281-
"""
282-
This function uses numpy's meshgrid to create a grid of indices from the input slices,
283-
and reshapes the grid into a list of all possible index combinations.
284-
285-
Example:
286-
slices = [range(2), range(3)]
287-
# This represents two dimensions:
288-
# - The first dimension has indices [0, 1]
289-
# - The second dimension has indices [0, 1, 2]
290-
291-
result = _get_slice_indices(slices)
292-
# result:
293-
# [[0, 0],
294-
# [0, 1],
295-
# [0, 2],
296-
# [1, 0],
297-
# [1, 1],
298-
# [1, 2]]
299-
"""
300-
import numpy as np
301-
302-
grids = np.meshgrid(*slices, indexing="ij")
303-
return np.array(grids).reshape(len(slices), -1).T
304-
305-
306-
def _set_flow_bins(self, target_hist, out_of_range_indices):
307-
"""
308-
Accumulate content from bins outside the slice range into flow bins.
309-
"""
310-
dim = self.GetDimension()
311-
uflow_bin = tuple(_underflow(self, axis) for axis in range(dim))
312-
oflow_bin = tuple(_overflow(self, axis) for axis in range(dim))
313-
flow_sum = 0
314-
315-
for axis, (underflow_indices, overflow_indices) in enumerate(out_of_range_indices):
316-
all_axes = [range(_overflow(self, j)) for j in range(dim)]
317-
318-
def sum_bin_content(indices_list, target_bin):
319-
current_val = target_hist.GetBinContent(*target_bin)
320-
temp_axes = list(all_axes)
321-
temp_axes[axis] = indices_list
322-
for idx in _get_slice_indices(temp_axes):
323-
current_val += self.GetBinContent(*tuple(map(int, idx)))
324-
target_hist.SetBinContent(*target_bin, current_val)
325-
return current_val
326-
327-
flow_sum += sum_bin_content(underflow_indices, uflow_bin)
328-
flow_sum += sum_bin_content(overflow_indices, oflow_bin)
329-
330-
return flow_sum
295+
return processed_slices, actions
331296

332297

333-
def _slice_get(self, index):
298+
def _slice_get(self, index, unprocessed_index):
334299
"""
335300
This method creates a new histogram containing only the data from the
336301
specified slice.
337302
338303
Steps:
339304
- Process the slices and extract the actions for each axis.
340-
- Clone the original histogram and reset its contents.
341-
- Set the bin content for each index in the slice.
342-
- Update the number of entries in the cloned histogram (also updates the statistics).
305+
- Get a new sliced histogram.
343306
- Apply any rebinning or summing actions to the resulting histogram.
344307
"""
345-
processed_slices, out_of_range_indices, actions = _get_processed_slices(self, index)
346-
slice_indices = _get_slice_indices(processed_slices)
347-
with _temporarily_disable_add_directory():
348-
target_hist = self.Clone()
349-
target_hist.Reset()
350-
351-
for indices in slice_indices:
352-
indices = tuple(map(int, indices))
353-
target_hist.SetBinContent(*indices, self.GetBinContent(self.GetBin(*indices)))
308+
import ROOT
354309

355-
flow_sum = _set_flow_bins(self, target_hist, out_of_range_indices)
310+
processed_slices, actions = _get_processed_slices(self, index)
311+
start_stop = [(r.start, r.stop) for r in processed_slices]
312+
slice_args = [item for pair in start_stop for item in pair]
356313

357-
target_hist.SetEntries(target_hist.GetEffectiveEntries() + flow_sum)
314+
target_hist = ROOT.Internal.Slice(self, *slice_args)
358315

359-
return _apply_actions(target_hist, actions)
316+
return _apply_actions(target_hist, actions, index, unprocessed_index, self)
360317

361318

362319
def _slice_set(self, index, unprocessed_index, value):
@@ -367,42 +324,50 @@ def _slice_set(self, index, unprocessed_index, value):
367324
"""
368325
import numpy as np
369326

327+
import ROOT
328+
329+
if isinstance(value, (list, range)):
330+
value = np.array(value)
331+
370332
# Depending on the shape of the array provided, we can set or not the flow bins
371333
# Setting with a scalar does not set the flow bins
372-
include_flow_bins = not (
373-
(isinstance(value, np.ndarray) and value.shape == _shape(self, include_flow_bins=False)) or np.isscalar(value)
334+
include_flow_bins = (
335+
(isinstance(value, np.ndarray) and value.shape != _shape(_slice_get(self, index, unprocessed_index), include_flow_bins=False)) or np.isscalar(value)
374336
)
375337
if not include_flow_bins:
376338
index = _compute_common_index(self, unprocessed_index, include_flow_bins=False)
377339

378-
processed_slices, _, actions = _get_processed_slices(self, index)
379-
slice_indices = _get_slice_indices(processed_slices)
380-
if isinstance(value, np.ndarray):
381-
if value.size != len(slice_indices):
382-
raise ValueError(f"Expected {len(slice_indices)} bin values, got {value.size}")
383-
384-
expected_shape = tuple(len(slice_range) for slice_range in processed_slices)
385-
if value.shape != expected_shape:
386-
raise ValueError(f"Shape mismatch: expected {expected_shape}, got {value.shape}")
387-
388-
for indices, val in zip(slice_indices, value.ravel()):
389-
_setbin(self, self.GetBin(*map(int, indices)), val)
390-
elif np.isscalar(value):
391-
for indices in slice_indices:
392-
_setbin(self, self.GetBin(*map(int, indices)), value)
340+
processed_slices, actions = _get_processed_slices(self, index)
341+
start_stop = [(r.start, r.stop) for r in processed_slices]
342+
slice_shape = tuple(stop - start for start, stop in start_stop)
343+
slice_args = [item for pair in start_stop for item in pair]
344+
345+
if np.isscalar(value):
346+
value = ROOT.std.variant('std::vector<Double_t>', 'Double_t')(float(value))
393347
else:
394-
raise TypeError(f"Unsupported value type: {type(value).__name__}")
395-
396-
_apply_actions(self, actions)
348+
try:
349+
value = np.asanyarray(value)
350+
if value.size != np.prod(slice_shape):
351+
try:
352+
value = np.broadcast_to(value, slice_shape)
353+
except ValueError:
354+
raise ValueError(f"Expected {np.prod(slice_shape)} bin values, got {value.size}")
355+
value_vector = ROOT.std.vector('Double_t')(value.flatten().astype(np.float64))
356+
value = ROOT.std.variant('std::vector<Double_t>', 'Double_t')(value_vector)
357+
except AttributeError:
358+
raise TypeError(f"Unsupported value type: {type(value).__name__}")
359+
360+
ROOT.Internal.SetSliceContent(self, value, *slice_args)
361+
362+
_apply_actions(self, actions, index, unprocessed_index, self)
397363

398364

399365
def _getitem(self, index):
400366
uhi_index = _compute_common_index(self, index)
401367
if all(isinstance(i, int) for i in uhi_index):
402368
return self.GetBinContent(*uhi_index)
403-
404369
if any(isinstance(i, slice) for i in uhi_index):
405-
return _slice_get(self, uhi_index)
370+
return _slice_get(self, uhi_index, index)
406371

407372

408373
def _setitem(self, index, value):

0 commit comments

Comments
 (0)