From 185435826ee3226d7b4b2fadf7450fd0212a75ef Mon Sep 17 00:00:00 2001 From: ML Metrics Team Date: Fri, 19 Dec 2025 17:16:10 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 846919736 --- ml_metrics/_src/aggregates/stats.py | 246 --------------- ml_metrics/_src/aggregates/stats_test.py | 371 ----------------------- 2 files changed, 617 deletions(-) diff --git a/ml_metrics/_src/aggregates/stats.py b/ml_metrics/_src/aggregates/stats.py index 6d3be5d4..1ab0a710 100644 --- a/ml_metrics/_src/aggregates/stats.py +++ b/ml_metrics/_src/aggregates/stats.py @@ -19,7 +19,6 @@ import collections from collections.abc import Callable, Iterable import dataclasses -import enum import itertools import math from typing import Any, Generic, Self, TypeVar @@ -30,44 +29,10 @@ from ml_metrics._src.tools.telemetry import telemetry import numpy as np -from tensorflow_metadata.proto.v0 import statistics_pb2 - _EPSNEG = np.finfo(float).epsneg _T = TypeVar('_T') -@enum.unique -class FeatureType(enum.Enum): - """Feature types.""" - - INT = 'int' - FLOAT = 'float' - STRING = 'string' - - -def _get_feature_type(value: list[Any]) -> FeatureType | None: - """Returns the feature type of the given value.""" - if not value: - return None - if isinstance(value[0], str): - return FeatureType.STRING - elif isinstance(value[0], bytes): - try: - value[0].decode('utf-8') - return FeatureType.STRING - except UnicodeDecodeError as e: - raise ValueError( - 'Unsupported bytes feature type. Feature could not be decoded as' - f' UTF-8 string: {e}' - ) from e - elif isinstance(value[0], int): - return FeatureType.INT - elif isinstance(value[0], float): - return FeatureType.FLOAT - else: - raise ValueError(f'Unsupported feature type: {type(value[0])}') - - @telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS) @dataclasses.dataclass(kw_only=True) class UnboundedSampler(chainable.CallableMetric, chainable.HasAsAggFn): @@ -519,217 +484,6 @@ def __str__(self): return f'var: {self.var}' -@dataclasses.dataclass(slots=True) -class NumericStats: - """Statistics for a numeric feature.""" - - min: float = np.inf - max: float = -np.inf - num_zeros: int = 0 - mean: float = np.nan - std_dev: float = np.nan - mva: MeanAndVariance = dataclasses.field(default_factory=MeanAndVariance) - - def update(self, value: list[Any]): - self.mva.add(value) - self.min = np.minimum(self.min, np.min(value)) - self.max = np.maximum(self.max, np.max(value)) - self.num_zeros += np.sum(np.equal(value, 0)) - - def merge(self, other: Self): - if other.mva.count > 0: - self.mva.merge(other.mva) - self.min = min(self.min, other.min) - self.max = max(self.max, other.max) - self.num_zeros += other.num_zeros - - def compute_result(self): - """Computes the final statistics and returns a new NumericStats.""" - mva_res = self.mva.result() - self.mean = mva_res.mean - self.std_dev = mva_res.stddev - - -@dataclasses.dataclass(slots=True) -class FeatureStats: - """Statistics for a single feature.""" - - feature_type: FeatureType | None = None - num_missing: int = 0 - num_non_missing: int = 0 - max_num_values: int = 0 - min_num_values: int | None = None - tot_num_values: int = 0 - avg_num_values: float = 0.0 - numeric_stats: NumericStats | None = None - - def update(self, value: list[Any]): - """Updates the feature stats with a new value.""" - feature_type = _get_feature_type(value) - length = len(value) - fs = FeatureStats( - feature_type=feature_type, - num_non_missing=1, - max_num_values=length, - min_num_values=length, - tot_num_values=length, - avg_num_values=float(length), - ) - if feature_type in [FeatureType.INT, FeatureType.FLOAT] and value: - fs.numeric_stats = NumericStats() - fs.numeric_stats.update(value) - self.merge(fs) - - def merge(self, other: Self): - """Merges with other feature stats.""" - if other.feature_type is not None: - if self.feature_type is None: - self.feature_type = other.feature_type - elif self.feature_type != other.feature_type: - raise ValueError( - f'Feature has conflicting types: {self.feature_type} vs' - f' {other.feature_type}' - ) - self.num_non_missing += other.num_non_missing - self.max_num_values = max(self.max_num_values, other.max_num_values) - if other.min_num_values is None: - pass - elif self.min_num_values is None: - self.min_num_values = other.min_num_values - else: - self.min_num_values = min(self.min_num_values, other.min_num_values) - self.tot_num_values += other.tot_num_values - self.avg_num_values = ( - self.tot_num_values / self.num_non_missing - if self.num_non_missing - else 0.0 - ) - if other.numeric_stats is not None: - if self.numeric_stats is None: - self.numeric_stats = NumericStats() - self.numeric_stats.merge(other.numeric_stats) - - -@dataclasses.dataclass(slots=True) -class TfExampleStats: - """The result of TfExampleStatsAgg.""" - - num_examples: int = 0 - feature_stats: dict[str, FeatureStats] = dataclasses.field( - default_factory=dict - ) - - def to_proto(self): - """Writes the data to the sink.""" - feature_stats_proto = statistics_pb2.DatasetFeatureStatistics( - num_examples=self.num_examples - ) - for name, feature_stats in self.feature_stats.items(): - feature_name_stats = feature_stats_proto.features.add( - path={'step': [name]} - ) - feature_name_stats.num_stats.common_stats.num_missing = ( - feature_stats.num_missing - ) - feature_name_stats.num_stats.common_stats.num_non_missing = ( - feature_stats.num_non_missing - ) - feature_name_stats.num_stats.common_stats.min_num_values = ( - feature_stats.min_num_values - ) - feature_name_stats.num_stats.common_stats.max_num_values = ( - feature_stats.max_num_values - ) - feature_name_stats.num_stats.common_stats.avg_num_values = ( - feature_stats.avg_num_values - ) - feature_name_stats.num_stats.common_stats.tot_num_values = ( - feature_stats.tot_num_values - ) - if feature_stats.feature_type: - feature_name_stats.type = ( - statistics_pb2.FeatureNameStatistics.Type.Value( - feature_stats.feature_type.name - ) - ) - return statistics_pb2.DatasetFeatureStatisticsList( - datasets=[feature_stats_proto] - ) - - -@telemetry.class_monitor(category=telemetry.CATEGORY.STATS) -@dataclasses.dataclass(slots=True, kw_only=True) -class TfExampleStatsAgg(chainable.CallableMetric): - """Computes statistics on features.""" - - batched_inputs: bool = True - _stats: TfExampleStats = dataclasses.field(default_factory=TfExampleStats) - - def as_agg_fn(self) -> chainable.AggregateFn: - return chainable.as_agg_fn( - self.__class__, - batched_inputs=self.batched_inputs, - ) - - @property - def num_examples(self) -> int: - return self._stats.num_examples - - @property - def feature_stats(self) -> dict[str, FeatureStats]: - return self._stats.feature_stats - - def new( - self, inputs: dict[str, list[Any]] | list[dict[str, list[Any]]] - ) -> Self: - """Computes the sufficient statistics of a batch of inputs.""" - - # Check if the input is a single dict for when batched_inputs is False - if isinstance(inputs, dict): - inputs_list = [inputs] - else: - inputs_list = inputs - - num_examples = 0 - feature_stats = collections.defaultdict(FeatureStats) - for example in inputs_list: - num_examples += 1 - for key, value in example.items(): - try: - feature_stats[key].update(value) - except ValueError as e: - if 'conflicting types' in str(e): - raise ValueError(f'Feature {key} has conflicting types: {e}') from e - raise - return self.__class__( - batched_inputs=self.batched_inputs, - _stats=TfExampleStats( - num_examples=num_examples, - feature_stats=dict(feature_stats), - ), - ) - - def merge(self, other: Self) -> None: - self._stats.num_examples += other.num_examples - for key, value in other.feature_stats.items(): - if key in self._stats.feature_stats: - try: - self._stats.feature_stats[key].merge(value) - except ValueError as e: - if 'conflicting types' in str(e): - raise ValueError(f'Feature {key} has conflicting types: {e}') from e - raise - else: - self._stats.feature_stats[key] = value - - def result(self) -> TfExampleStats: - for feature in self._stats.feature_stats.values(): - feature.num_missing = self._stats.num_examples - feature.num_non_missing - if feature.numeric_stats is not None: - feature.numeric_stats.compute_result() - return self._stats - - @telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS) # TODO(b/345249574): Add a preprocessing function of len per row. @dataclasses.dataclass(slots=True) diff --git a/ml_metrics/_src/aggregates/stats_test.py b/ml_metrics/_src/aggregates/stats_test.py index d48fd957..6eae848f 100644 --- a/ml_metrics/_src/aggregates/stats_test.py +++ b/ml_metrics/_src/aggregates/stats_test.py @@ -23,7 +23,6 @@ import numpy as np from absl.testing import absltest -from tensorflow_metadata.proto.v0 import statistics_pb2 class HistogramTest(parameterized.TestCase): @@ -659,350 +658,6 @@ def test_str(self): self.assertEqual('count: 3', str(count)) -class FeatureStatsTest(parameterized.TestCase): - - def test_avg_num_values(self): - stats1 = stats.FeatureStats( - num_missing=0, - num_non_missing=10, - max_num_values=5, - min_num_values=1, - tot_num_values=30, - avg_num_values=3.0, - ) - self.assertAlmostEqual(3.0, stats1.avg_num_values) - - def test_avg_num_values_zero_non_missing(self): - stats1 = stats.FeatureStats( - num_missing=0, - num_non_missing=0, - max_num_values=0, - min_num_values=None, - tot_num_values=0, - avg_num_values=0.0, - ) - self.assertAlmostEqual(0.0, stats1.avg_num_values) - - -class TfExampleStatsAggTest(parameterized.TestCase): - - def assert_feature_stats_equal(self, fs1, fs2, places=6): - self.assertEqual(fs1.feature_type, fs2.feature_type) - self.assertEqual(fs1.num_missing, fs2.num_missing) - self.assertEqual(fs1.num_non_missing, fs2.num_non_missing) - self.assertEqual(fs1.max_num_values, fs2.max_num_values) - self.assertEqual(fs1.min_num_values, fs2.min_num_values) - self.assertEqual(fs1.tot_num_values, fs2.tot_num_values) - self.assertAlmostEqual( - fs1.avg_num_values, fs2.avg_num_values, places=places - ) - if fs1.numeric_stats is None: - self.assertIsNone(fs2.numeric_stats) - else: - self.assertIsNotNone(fs2.numeric_stats) - self.assertEqual(fs1.numeric_stats.num_zeros, fs2.numeric_stats.num_zeros) - self.assertAlmostEqual( - fs1.numeric_stats.min, fs2.numeric_stats.min, places=places - ) - self.assertAlmostEqual( - fs1.numeric_stats.max, fs2.numeric_stats.max, places=places - ) - self.assertAlmostEqual( - fs1.numeric_stats.mean, fs2.numeric_stats.mean, places=places - ) - self.assertAlmostEqual( - fs1.numeric_stats.std_dev, fs2.numeric_stats.std_dev, places=places - ) - - def assert_tf_example_stats_equal(self, expected, actual, places=6): - self.assertEqual(expected.num_examples, actual.num_examples) - self.assertCountEqual( - expected.feature_stats.keys(), actual.feature_stats.keys() - ) - for key in expected.feature_stats: - with self.subTest(key): - self.assert_feature_stats_equal( - expected.feature_stats[key], - actual.feature_stats[key], - places=places, - ) - - def test_single_example(self): - examples = [{'a': [1], 'b': [1, 2]}] - agg = stats.TfExampleStatsAgg() - agg.add(examples) - expected = stats.TfExampleStats( - num_examples=1, - feature_stats={ - 'a': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=0, - num_non_missing=1, - max_num_values=1, - min_num_values=1, - tot_num_values=1, - avg_num_values=1.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=1.0, - mean=1.0, - std_dev=0.0, - num_zeros=0, - ), - ), - 'b': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=0, - num_non_missing=1, - max_num_values=2, - min_num_values=2, - tot_num_values=2, - avg_num_values=2.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=2.0, - mean=1.5, - std_dev=0.5, - num_zeros=0, - ), - ), - }, - ) - self.assert_tf_example_stats_equal(expected, agg.result()) - - def test_multiple_examples(self): - examples = [{'a': [1], 'b': [1, 2]}, {'a': [1, 2, 3]}] - agg = stats.TfExampleStatsAgg() - agg.add(examples) - expected = stats.TfExampleStats( - num_examples=2, - feature_stats={ - 'a': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=0, - num_non_missing=2, - max_num_values=3, - min_num_values=1, - tot_num_values=4, - avg_num_values=2.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=3.0, - mean=1.75, - std_dev=0.8291561975800501, - num_zeros=0, - ), - ), - 'b': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=1, - num_non_missing=1, - max_num_values=2, - min_num_values=2, - tot_num_values=2, - avg_num_values=2.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=2.0, - mean=1.5, - std_dev=0.5, - num_zeros=0, - ), - ), - }, - ) - self.assert_tf_example_stats_equal(expected, agg.result()) - - def test_merge(self): - examples1 = [{'a': [1], 'b': [1, 2]}, {'a': [1, 2, 3]}] - examples2 = [{'b': [1, 2, 3, 4], 'c': [1]}] - agg1 = stats.TfExampleStatsAgg() - agg1.add(examples1) - agg2 = stats.TfExampleStatsAgg() - agg2.add(examples2) - agg1.merge(agg2) - expected = stats.TfExampleStats( - num_examples=3, - feature_stats={ - 'a': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=1, - num_non_missing=2, - max_num_values=3, - min_num_values=1, - tot_num_values=4, - avg_num_values=2.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=3.0, - mean=1.75, - std_dev=0.8291561975800501, - num_zeros=0, - ), - ), - 'b': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=1, - num_non_missing=2, - max_num_values=4, - min_num_values=2, - tot_num_values=6, - avg_num_values=3.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=4.0, - mean=2.1666666666666665, - std_dev=1.067187372604722, - num_zeros=0, - ), - ), - 'c': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=2, - num_non_missing=1, - max_num_values=1, - min_num_values=1, - tot_num_values=1, - avg_num_values=1.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=1.0, - mean=1.0, - std_dev=0.0, - num_zeros=0, - ), - ), - }, - ) - self.assert_tf_example_stats_equal(expected, agg1.result()) - - def test_unbatched(self): - examples = [{'a': [1], 'b': [1, 2]}, {'a': [1, 2, 3]}] - agg = stats.TfExampleStatsAgg(batched_inputs=False) - for example in examples: - agg.add(example) - expected = stats.TfExampleStats( - num_examples=2, - feature_stats={ - 'a': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=0, - num_non_missing=2, - max_num_values=3, - min_num_values=1, - tot_num_values=4, - avg_num_values=2.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=3.0, - mean=1.75, - std_dev=0.8291561975800501, - num_zeros=0, - ), - ), - 'b': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=1, - num_non_missing=1, - max_num_values=2, - min_num_values=2, - tot_num_values=2, - avg_num_values=2.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=2.0, - mean=1.5, - std_dev=0.5, - num_zeros=0, - ), - ), - }, - ) - self.assert_tf_example_stats_equal(expected, agg.result()) - - def test_feature_types(self): - examples = [{'a': [1], 'b': [1.0, 2.0], 'c': ['foo', 'bar']}] - agg = stats.TfExampleStatsAgg() - agg.add(examples) - expected = stats.TfExampleStats( - num_examples=1, - feature_stats={ - 'a': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=0, - num_non_missing=1, - max_num_values=1, - min_num_values=1, - tot_num_values=1, - avg_num_values=1.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=1.0, - mean=1.0, - std_dev=0.0, - num_zeros=0, - ), - ), - 'b': stats.FeatureStats( - feature_type=stats.FeatureType.FLOAT, - num_missing=0, - num_non_missing=1, - max_num_values=2, - min_num_values=2, - tot_num_values=2, - avg_num_values=2.0, - numeric_stats=stats.NumericStats( - min=1.0, - max=2.0, - mean=1.5, - std_dev=0.5, - num_zeros=0, - ), - ), - 'c': stats.FeatureStats( - feature_type=stats.FeatureType.STRING, - num_missing=0, - num_non_missing=1, - max_num_values=2, - min_num_values=2, - tot_num_values=2, - avg_num_values=2.0, - numeric_stats=None, - ), - }, - ) - self.assert_tf_example_stats_equal( - expected, - agg.result(), - ) - - def test_conflicting_types_in_batch(self): - examples = [{'a': [1]}, {'a': ['foo']}] - agg = stats.TfExampleStatsAgg() - with self.assertRaisesRegex(ValueError, 'Feature a has conflicting types'): - agg.add(examples) - - def test_conflicting_types_in_merge(self): - examples1 = [{'a': [1]}] - examples2 = [{'a': ['foo']}] - agg1 = stats.TfExampleStatsAgg() - agg1.add(examples1) - agg2 = stats.TfExampleStatsAgg() - agg2.add(examples2) - with self.assertRaisesRegex(ValueError, 'Feature a has conflicting types'): - agg1.merge(agg2) - - def test_bytes_feature_type_invalid_utf8_error(self): - examples = [{'a': [b'\xff']}] - agg = stats.TfExampleStatsAgg() - with self.assertRaisesRegex( - ValueError, - 'Unsupported bytes feature type. Feature could not be decoded as UTF-8' - ' string', - ): - agg.add(examples) - - class MeanAndVarianceTest(parameterized.TestCase): def assertDataclassAlmostEqual( @@ -2084,31 +1739,5 @@ def test_symmetric_prediction_difference_asserts_with_invalid_input(self): metric.add(x, y) -class TfdvTest(absltest.TestCase): - - def test_to_proto(self): - feature_stats_instance = stats.FeatureStats() - feature_stats_instance.update([1]) - feature_stats_instance.update([1, 2]) - data = stats.TfExampleStats( - num_examples=2, - feature_stats={ - 'feature1': feature_stats_instance, - }, - ) - stats_list = data.to_proto() - self.assertLen(stats_list.datasets, 1) - self.assertEqual(stats_list.datasets[0].num_examples, 2) - self.assertLen(stats_list.datasets[0].features, 1) - feature = stats_list.datasets[0].features[0] - self.assertEqual(feature.path.step[0], 'feature1') - self.assertEqual(feature.num_stats.common_stats.num_non_missing, 2) - self.assertEqual(feature.num_stats.common_stats.min_num_values, 1) - self.assertEqual(feature.num_stats.common_stats.max_num_values, 2) - self.assertEqual(feature.num_stats.common_stats.avg_num_values, 1.5) - self.assertEqual(feature.num_stats.common_stats.tot_num_values, 3) - self.assertEqual(feature.type, statistics_pb2.FeatureNameStatistics.INT) - - if __name__ == '__main__': absltest.main()