ml_metrics/_src/aggregates/rolling_stats.py (33 additions, 17 deletions)
@@ -128,22 +128,35 @@ def _add_samples_to_reservoir(self, samples: list[Any], n: int):
self._logw += np.log(self._rng.uniform(low=_EPSNEG)) / self.max_size
self._num_samples_reviewed += n

def _merge_reservoirs(self, other: FixedSizeSample) -> list[Any]:
# TODO: b/370053191 - For efficiency, sample from the combined reservoir
# in one-shot.
result = []
num_samples_orig = self._num_samples_reviewed
reservoir_new, num_samples_new = other.reservoir, other.num_samples_reviewed
while len(result) < self.max_size and num_samples_orig + num_samples_new:
thr_from_orig = num_samples_orig / (num_samples_orig + num_samples_new)
if self._rng.uniform() < thr_from_orig:
sample = self._reservoir.pop(self._rng.integers(len(self._reservoir)))
num_samples_orig -= 1
else:
sample = reservoir_new.pop(self._rng.integers(len(reservoir_new)))
num_samples_new -= 1
result.append(sample)
return result
def _merge_reservoirs(
self, max_size: int, num_samples_reviewed: int, reservoir: list[Any]
) -> list[Any]:

# Use the number of samples reviewed to make a weighted random choice of
# which reservoir contributes to the new reservoir first.
if self._rng.uniform() < self._num_samples_reviewed / (
self._num_samples_reviewed + num_samples_reviewed
):
first, second = self._reservoir, reservoir
else:
first, second = reservoir, self._reservoir

# Fill new reservoir with samples from the first reservoir.
if len(first) == self.max_size:
return first

if len(first) > self.max_size:
return self._rng.choice(first, self.max_size, replace=False)

# Finish filling the new reservoir with samples from the second reservoir.
if len(second) <= max_size - len(first):
return list(first) + list(second)

return list(first) + list(
self._rng.choice(
second, min(len(second), self.max_size - len(first)), replace=False
)
)

def add(self, inputs: types.NumbersT):
self._add_samples_to_reservoir(inputs, n=len(inputs))
@@ -154,7 +167,10 @@ def merge(self, other: FixedSizeSample):
'The seeds of the two samplers must be equal, but received'
f' self.seed={self.seed} and other.seed={other.seed}.'
)
self._reservoir = self._merge_reservoirs(other)

self._reservoir = self._merge_reservoirs(
other.max_size, other.num_samples_reviewed, other.reservoir
)
self._num_samples_reviewed += other.num_samples_reviewed
self._logw += other.logw

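To make the new merge path easier to follow outside the diff, here is a minimal standalone sketch of the strategy the new _merge_reservoirs code above follows: pick which reservoir contributes first with probability proportional to how many samples each has reviewed, then top the result up from the other reservoir, downsampling without replacement when it would overflow the target size. The function and parameter names (merge_reservoirs, seen_a, seen_b) are illustrative only; this is a sketch of the idea, not the shipped implementation.

import numpy as np


def merge_reservoirs(
    rng: np.random.Generator,
    max_size: int,
    reservoir_a: list,
    seen_a: int,
    reservoir_b: list,
    seen_b: int,
) -> list:
  """Sketch of a weighted reservoir merge; not the library implementation."""
  # Weighted coin flip: the reservoir that has reviewed more samples is more
  # likely to contribute its elements first.
  if rng.uniform() < seen_a / (seen_a + seen_b):
    first, second = reservoir_a, reservoir_b
  else:
    first, second = reservoir_b, reservoir_a

  # If the first reservoir alone fills (or overfills) the target size,
  # downsample it without replacement and stop.
  if len(first) >= max_size:
    return list(rng.choice(first, max_size, replace=False))

  # Otherwise top up from the second reservoir, downsampling it when it would
  # overflow the remaining room.
  room = max_size - len(first)
  if len(second) <= room:
    return list(first) + list(second)
  return list(first) + list(rng.choice(second, room, replace=False))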
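In terms of the public API exercised by the test file below, a merge looks roughly like this. A usage sketch, assuming rolling_stats imports as it does in the tests and using illustrative data:

import numpy as np
from ml_metrics._src.aggregates import rolling_stats

a = rolling_stats.FixedSizeSample(max_size=5, seed=0)
b = rolling_stats.FixedSizeSample(max_size=5, seed=0)  # Seeds must match to merge.
a.add(np.arange(0, 50))
b.add(np.arange(50, 100))
a.merge(b)  # a's reservoir is now a fixed-size sample drawn from both streams.
print(a.result())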
ml_metrics/_src/aggregates/rolling_stats_test.py (47 additions, 13 deletions)
@@ -303,26 +303,26 @@ def _assert_fixed_size_samples_equal(self, actual, expected):
reservoir_other=[6, 7, 8, 9, 10],
num_samples_original=5,
num_samples_other=5,
expected_reservoir=(2, 1, 9, 10, 7),
expected_reservoir=(1, 2, 3, 4, 5),
expected_num_samples_reviewed=10,
),
dict(
testcase_name='only_original_reservoir_full',
reservoir_original=[1, 2, 3, 4, 5],
reservoir_other=[6, 7],
num_samples_original=5,
num_samples_other=2,
expected_reservoir=(2, 1, 7, 5, 6),
expected_num_samples_reviewed=7,
num_samples_other=20,
expected_reservoir=(1, 4, 5, 6, 7),
expected_num_samples_reviewed=25,
),
dict(
testcase_name='only_other_reservoir_full',
reservoir_original=[1, 2, 3],
reservoir_other=[6, 7, 8, 9, 10],
num_samples_original=3,
num_samples_original=20,
num_samples_other=5,
expected_reservoir=(1, 2, 9, 10, 7),
expected_num_samples_reviewed=8,
expected_reservoir=(1, 2, 3, 6, 7),
expected_num_samples_reviewed=25,
),
dict(
testcase_name='neither_reservoir_full',
@@ -347,9 +347,9 @@ def _assert_fixed_size_samples_equal(self, actual, expected):
reservoir_original=[1, 2, 3, 4, 5],
reservoir_other=[6, 7, 8, 9, 10],
num_samples_original=11,
num_samples_other=7,
expected_reservoir=(2, 1, 9, 10, 7),
expected_num_samples_reviewed=18,
num_samples_other=100,
expected_reservoir=(6, 7, 8, 9, 10),
expected_num_samples_reviewed=111,
),
)
def test_fixed_size_sample_merge(
@@ -389,13 +389,13 @@ def test_fixed_size_sample_merge_different_max_sizes(self):
max_size=5,
seed=0,
_reservoir=res_1,
_num_samples_reviewed=10,
_num_samples_reviewed=5,
)
other = rolling_stats.FixedSizeSample(
max_size=10, seed=0, _reservoir=res_2, _num_samples_reviewed=10
max_size=10, seed=0, _reservoir=res_2, _num_samples_reviewed=20
)
sampler.merge(other)
expected_result = (2, 1, 70, 100, 60)
expected_result = (20, 10, 90, 80, 100)
np.testing.assert_array_equal(sampler.result(), expected_result)

def test_fixed_size_sample_merge_smaller_samples_than_max_size(self):
@@ -570,6 +570,40 @@ def test_sampling_merge_uniformness(self):
actual_counter /= num_runs
np.testing.assert_array_less(actual_counter - max_size / max_range, 0.03)

def test_sampling_merge_uniformness_weighted_reservoirs(self):
max_range, max_size = 100, 10
actual_counter = np.zeros(max_range)
num_runs = 1000
testing_buckets = 10

for i in range(num_runs):
sampler = rolling_stats.FixedSizeSample(max_size=max_size, seed=i)
for batch in mit.batched(np.arange(max_range), 9):
other = rolling_stats.FixedSizeSample(max_size=max_size, seed=i)
other.add(batch)

# Weight the choice between the two reservoirs equally.
if sampler._num_samples_reviewed != 0:
other._num_samples_reviewed = sampler._num_samples_reviewed

sampler.merge(other)

for v in sampler.result():
actual_counter[v] += 1

actual_counter /= num_runs

# We expect the last bucket to show up with probability 1/2, the
# second-to-last bucket with probability 1/2**2, and so on.
for i in range(testing_buckets):
np.testing.assert_allclose(
actual_counter[i : i + testing_buckets],
1 / (2 ** (testing_buckets - i)),
rtol=0.99,
atol=0.005,
)

def test_as_agg_fn(self):
sampler = rolling_stats.FixedSizeSample(max_size=5, seed=0)
sampler_agg_fn = sampler.as_agg_fn()
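As a back-of-the-envelope check on the expectation encoded in test_sampling_merge_uniformness_weighted_reservoirs above: when every merge weights the two reservoirs equally, a batch has to win roughly a fair coin flip at each subsequent merge to stay in the reservoir, so the k-th most recent batch should survive with probability of about (1/2)**k. A small illustrative calculation, not part of the test file:

# Approximate survival probability of the k-th most recent batch under
# equal-weight merges.
survival = [0.5 ** k for k in range(1, 11)]
print(survival)       # [0.5, 0.25, 0.125, ...]
print(sum(survival))  # ~0.999: nearly all of the mass sits on recent batches.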