Skip to content

Commit 9e29340

Browse files
committed
[python][UHI] Update the tests with the expected slicing logic
1 parent 57206da commit 9e29340

File tree

1 file changed

+56
-29
lines changed

1 file changed

+56
-29
lines changed

bindings/pyroot/pythonizations/test/uhi_indexing.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,7 @@
44

55
import pytest
66
import ROOT
7-
from ROOT._pythonization._uhi import (
8-
_get_axis,
9-
_get_processed_slices,
10-
_get_slice_indices,
11-
_shape,
12-
)
7+
from ROOT._pythonization._uhi import _get_axis, _get_processed_slices, _overflow, _shape, _underflow
138
from ROOT.uhi import loc, overflow, rebin, sum, underflow
149

1510

@@ -32,6 +27,12 @@ def _iterate_bins(hist):
3227
yield tuple(filter(None, (i, j, k)))
3328

3429

30+
def _get_slice_indices(slices):
31+
import numpy as np
32+
33+
grids = np.meshgrid(*slices, indexing="ij")
34+
return np.array(grids).reshape(len(slices), -1).T
35+
3536
class TestTH1Indexing:
3637
def test_access_with_bin_number(self, hist_setup):
3738
for index in [0, 8]:
@@ -52,7 +53,7 @@ def test_access_flow_bins(self, hist_setup):
5253

5354
def test_access_with_len(self, hist_setup):
5455
len_indices = (len,) * hist_setup.GetDimension()
55-
bin_counts = (_get_axis(hist_setup, i).GetNbins() for i in range(hist_setup.GetDimension()))
56+
bin_counts = (_get_axis(hist_setup, i).GetNbins() + 1 for i in range(hist_setup.GetDimension()))
5657
assert hist_setup[len_indices] == hist_setup.GetBinContent(*bin_counts)
5758

5859
def test_access_with_ellipsis(self, hist_setup):
@@ -110,17 +111,34 @@ def test_setting_with_scalar(self, hist_setup):
110111

111112
def _test_slices_match(self, hist_setup, slice_ranges, processed_slices):
112113
dim = hist_setup.GetDimension()
113-
slices, _, _ = _get_processed_slices(hist_setup, processed_slices[dim])
114+
slices, _ = _get_processed_slices(hist_setup, processed_slices[dim])
114115
expected_indices = _get_slice_indices(slices)
115116
sliced_hist = hist_setup[tuple(slice_ranges[dim])]
116117

117118
for bin_indices in expected_indices:
118119
bin_indices = tuple(map(int, bin_indices))
119-
assert sliced_hist.GetBinContent(*bin_indices) == hist_setup.GetBinContent(*bin_indices)
120-
121-
for bin_indices in _iterate_bins(hist_setup):
122-
if list(bin_indices) not in expected_indices.tolist():
123-
assert sliced_hist.GetBinContent(*bin_indices) == 0
120+
shifted_indices = []
121+
is_flow_bin = False
122+
for i, idx in enumerate(bin_indices):
123+
shift = slice_ranges[dim][i].start
124+
if callable(shift):
125+
shift = shift(hist_setup, i)
126+
elif shift is None:
127+
shift = 1
128+
else:
129+
shift += 1
130+
131+
shifted_idx = idx - shift + 1
132+
if shifted_idx <= 0 or shifted_idx == _overflow(hist_setup, i):
133+
is_flow_bin = True
134+
break
135+
136+
shifted_indices.append(shifted_idx)
137+
138+
if is_flow_bin:
139+
continue
140+
141+
assert sliced_hist.GetBinContent(*tuple(shifted_indices)) == hist_setup.GetBinContent(*bin_indices)
124142

125143
def test_slicing_with_endpoints(self, hist_setup):
126144
if _special_setting(hist_setup):
@@ -144,13 +162,13 @@ def test_slicing_without_endpoints(self, hist_setup):
144162

145163
processed_slices = {
146164
1: [slice(0, 8)],
147-
2: [slice(0, 8), slice(4, 11)],
148-
3: [slice(0, 8), slice(4, 11), slice(3, 6)],
165+
2: [slice(0, 8), slice(0, 8)],
166+
3: [slice(0, 8), slice(0, 8), slice(3, 6)],
149167
}
150168
slice_ranges = {
151169
1: [slice(None, 7)],
152-
2: [slice(None, 7), slice(3, None)],
153-
3: [slice(None, 7), slice(3, None), slice(2, 5)],
170+
2: [slice(None, 7), slice(None, 7)],
171+
3: [slice(None, 7), slice(None, 7), slice(2, 5)],
154172
}
155173
self._test_slices_match(hist_setup, slice_ranges, processed_slices)
156174

@@ -160,17 +178,17 @@ def test_slicing_with_data_coordinates(self, hist_setup):
160178

161179
processed_slices = {
162180
1: [slice(hist_setup.FindBin(2), 11)],
163-
2: [slice(hist_setup.FindBin(2), 11), slice(hist_setup.FindBin(3), 11)],
181+
2: [slice(hist_setup.FindBin(2)-1, 11), slice(2, 11)],
164182
3: [
165183
slice(hist_setup.FindBin(2), 11),
166-
slice(hist_setup.FindBin(3), 11),
167-
slice(hist_setup.FindBin(1.5), 11),
184+
slice(2, 11),
185+
slice(2, 11),
168186
],
169187
}
170188
slice_ranges = {
171189
1: [slice(loc(2), None)],
172-
2: [slice(loc(2), None), slice(loc(3), None)],
173-
3: [slice(loc(2), None), slice(loc(3), None), slice(loc(1.5), None)],
190+
2: [slice(loc(2), None), slice(3, None)],
191+
3: [slice(loc(2), None), slice(3, None), slice(3, None)],
174192
}
175193
self._test_slices_match(hist_setup, slice_ranges, processed_slices)
176194

@@ -191,7 +209,10 @@ def test_slicing_over_everything_with_action_sum(self, hist_setup):
191209
dim = hist_setup.GetDimension()
192210

193211
if dim == 1:
194-
integral = hist_setup[::sum]
212+
full_integral = hist_setup[::sum]
213+
assert full_integral == hist_setup.Integral(_underflow(hist_setup, 0), _overflow(hist_setup, 0))
214+
215+
integral = hist_setup[0:len:sum]
195216
assert integral == hist_setup.Integral()
196217

197218
if dim == 2:
@@ -225,7 +246,7 @@ def test_slicing_with_action_rebin_and_sum(self, hist_setup):
225246
if dim == 1:
226247
sliced_hist_rebin = hist_setup[5 : 9 : rebin(2)]
227248
assert isinstance(sliced_hist_rebin, ROOT.TH1)
228-
assert sliced_hist_rebin.GetNbinsX() == hist_setup.GetNbinsX() // 2
249+
assert sliced_hist_rebin.GetNbinsX() == 2
229250

230251
sliced_hist_sum = hist_setup[5:9:sum]
231252
assert isinstance(sliced_hist_sum, float)
@@ -237,10 +258,10 @@ def test_slicing_with_action_rebin_and_sum(self, hist_setup):
237258
assert sliced_hist.GetNbinsX() == hist_setup.GetNbinsX() // 2
238259

239260
if dim == 3:
240-
sliced_hist = hist_setup[:: rebin(2), ::sum, 5 : 9 : rebin(3)]
261+
sliced_hist = hist_setup[:: rebin(2), ::sum, 3 : 9 : rebin(3)]
241262
assert isinstance(sliced_hist, ROOT.TH2)
242263
assert sliced_hist.GetNbinsX() == hist_setup.GetNbinsX() // 2
243-
assert sliced_hist.GetNbinsY() == hist_setup.GetNbinsZ() // 3
264+
assert sliced_hist.GetNbinsY() == 2
244265

245266
def test_slicing_with_dict_syntax(self, hist_setup):
246267
if _special_setting(hist_setup):
@@ -262,16 +283,22 @@ def test_integral_full_slice(self, hist_setup):
262283
assert hist_setup.Integral() == pytest.approx(sliced_hist.Integral(), rel=10e-6)
263284

264285
def test_statistics_slice(self, hist_setup):
265-
if _special_setting(hist_setup):
286+
if _special_setting(hist_setup) or isinstance(hist_setup, (ROOT.TH1C, ROOT.TH2C, ROOT.TH3C)):
266287
pytest.skip("Setting cannot be tested here")
267288

289+
# Check if slicing over everything preserves the statistics
290+
sliced_hist_full = hist_setup[...]
291+
292+
assert hist_setup.GetEffectiveEntries() == sliced_hist_full.GetEffectiveEntries()
293+
assert hist_setup.Integral() == sliced_hist_full.Integral()
294+
268295
# Check if slicing over a range updates the statistics
269296
dim = hist_setup.GetDimension()
270-
[_get_axis(hist_setup, i).SetRange(3, 5) for i in range(dim)]
297+
[_get_axis(hist_setup, i).SetRange(3, 7) for i in range(dim)]
271298
slice_indices = tuple(slice(2, 7) for _ in range(dim))
272299
sliced_hist = hist_setup[slice_indices]
273300

274-
assert hist_setup.Integral() == pytest.approx(sliced_hist.Integral(), rel=1e-6)
301+
assert hist_setup.Integral() == sliced_hist.Integral()
275302
assert hist_setup.GetMean() == pytest.approx(sliced_hist.GetMean(), abs=1e-3)
276303
assert hist_setup.GetStdDev() == pytest.approx(sliced_hist.GetStdDev(), abs=1e-3)
277304
assert hist_setup.GetEffectiveEntries() == sliced_hist.GetEffectiveEntries()

0 commit comments

Comments
 (0)