def _merge_reservoirs(
    self, max_size: int, num_samples_reviewed: int, reservoir: list[Any]
) -> list[Any]:
  """Builds the merged reservoir from this sampler's and another's samples.

  One of the two reservoirs is chosen to go first by a random draw weighted
  by how many samples each side has reviewed. The result is filled from that
  reservoir and, if there is still room, topped up with a random subset of
  the other one.

  Args:
    max_size: Capacity of the other sampler. Kept for interface
      compatibility; the merged reservoir is bounded by `self.max_size`.
    num_samples_reviewed: Number of samples the other sampler has reviewed.
    reservoir: The other sampler's reservoir. It is not modified.

  Returns:
    A new list holding at most `self.max_size` elements taken from the two
    reservoirs. Elements are returned as stored (no numpy conversion).
  """
  # The merged result lives in this sampler, so only its capacity applies.
  del max_size
  capacity = self.max_size
  total_reviewed = self._num_samples_reviewed + num_samples_reviewed

  # Always draw first so seeded runs consume the generator the same way
  # regardless of which branch is taken below.
  draw = self._rng.uniform()
  # With nothing reviewed on either side there is no weight to compare (the
  # ratio would divide by zero), so default to this sampler's reservoir.
  own_first = (
      total_reviewed == 0
      or draw < self._num_samples_reviewed / total_reviewed
  )
  if own_first:
    first, second = self._reservoir, reservoir
  else:
    first, second = reservoir, self._reservoir

  def _pick(pool, count):
    # Sample positions rather than values so that elements keep their
    # original type and the result is a plain list, not a numpy array.
    chosen = self._rng.choice(len(pool), count, replace=False)
    return [pool[i] for i in chosen]

  # The first reservoir alone fills (or overflows) the merged reservoir.
  if len(first) >= capacity:
    if len(first) == capacity:
      # Copy so the merged reservoir never aliases a source list.
      return list(first)
    return _pick(first, capacity)

  # Top up the remaining room from the second reservoir.
  room = capacity - len(first)
  if len(second) <= room:
    return list(first) + list(second)
  return list(first) + _pick(second, room)
def test_sampling_merge_uniformness_weighted_reservoirs(self):
  """Checks element frequencies after repeated equally-weighted merges.

  For each seed, the values 0..99 are fed in batches of 9 through a fresh
  sampler that is merged into an accumulating sampler. Before every merge
  (except the first) the incoming sampler's reviewed count is overwritten
  with the accumulator's, so the choice of which reservoir goes first is a
  coin flip. The per-value frequency of ending up in the final reservoir is
  then compared against halving probabilities.
  """
  num_runs, testing_buckets = 1000, 10
  max_range, max_size = 100, 10
  hit_counts = np.zeros(max_range)

  for seed in range(num_runs):
    merged = rolling_stats.FixedSizeSample(max_size=max_size, seed=seed)
    for chunk in mit.batched(np.arange(max_range), 9):
      incoming = rolling_stats.FixedSizeSample(max_size=max_size, seed=seed)
      incoming.add(chunk)

      # Give both reservoirs the same weight in the merge's random choice.
      already_seen = merged._num_samples_reviewed
      if already_seen != 0:
        incoming._num_samples_reviewed = already_seen

      merged.merge(incoming)

    for value in merged.result():
      hit_counts[value] += 1

  frequencies = hit_counts / num_runs

  # The expected frequency doubles with each step: 1/2**10 for the window
  # starting at 0, up to 1/2 for the window starting at 9.
  # NOTE(review): the windows slide by one element rather than stepping by a
  # whole batch, and rtol=0.99 is very loose - confirm this matches the
  # intended "per bucket" check.
  for offset in range(testing_buckets):
    window = frequencies[offset : offset + testing_buckets]
    expected = 1 / 2 ** (testing_buckets - offset)
    np.testing.assert_allclose(window, expected, rtol=0.99, atol=0.005)