diff --git a/ml_metrics/_src/aggregates/stats.py b/ml_metrics/_src/aggregates/stats.py index f7ba6221..6d3be5d4 100644 --- a/ml_metrics/_src/aggregates/stats.py +++ b/ml_metrics/_src/aggregates/stats.py @@ -373,6 +373,183 @@ def __str__(self): return f'count: {self.result()}' +@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS) +@dataclasses.dataclass(kw_only=True, eq=True) +class Mean(chainable.CallableMetric): + """Computes the mean and variance of a batch of values.""" + + batch_score_fn: Callable[..., types.NumbersT] | None = None + _count: types.NumbersT = 0 + _mean: types.NumbersT = np.nan + _input_shape: tuple[int, ...] = () + + def as_agg_fn(self, *, nested: bool = False) -> chainable.AggregateFn: + return chainable.as_agg_fn( + self.__class__, + batch_score_fn=self.batch_score_fn if not nested else None, + nested=nested, + agg_preprocess_fn=self.batch_score_fn if nested else None, + ) + + def new(self, batch: types.NumbersT) -> types.NumbersT: + """Returns a new instance with the sufficient statistics reset and updated. + + If `batch_score_fn` is provided, it will evaluate the batch and assign a + score to each item. Subsequently, the statistics are computed based on + non-nan values within the batch. If a certain dimension in batch is all nan, + mean and variance corresponding to that dimension will be nan, count for + that dimension will be 0. + + Args: + batch: A non-vacant series of values. + + Returns: + Mean + """ + batch = np.asarray( + self.batch_score_fn(batch) if self.batch_score_fn else batch + ) + return self.__class__( + _count=np.sum(~np.isnan(batch), axis=0), + _mean=np.nanmean(batch, axis=0), + _input_shape=batch.shape if batch.size else (), + ) + + @property + def count(self) -> types.NumbersT: + return self._count + + @property + def mean(self) -> types.NumbersT: + return self._mean + + @property + def total(self) -> types.NumbersT: + return math_utils.where(self._count > 0, self._mean * self._count, 0) + + @property + def input_shape(self) -> tuple[int, ...]: + return self._input_shape + + def merge(self, other: Self) -> None: + if np.all(np.isnan(other.mean)): + return + self._input_shape = self._input_shape or other.input_shape + if other.input_shape and other.input_shape[1:] != self._input_shape[1:]: + raise ValueError( + f'Incompatible shape {other.input_shape} while the' + f' other have shape {self._input_shape}.' + ) + self._count += other.count + mean_diff = math_utils.nanadd(other.mean, -self._mean) + update = mean_diff * math_utils.safe_divide(other.count, self._count) + self._mean = math_utils.nanadd(self._mean, update) + + def result(self) -> Self: + return self.mean + + def __str__(self): + return f'mean: {self.mean}' + + +@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS) +@dataclasses.dataclass(kw_only=True, eq=True) +class MeanAndVariance(Mean): + """Computes the mean and variance of a batch of values.""" + + _var: types.NumbersT = np.nan + + def new(self, batch: types.NumbersT) -> types.NumbersT: + batch = np.asarray( + self.batch_score_fn(batch) if self.batch_score_fn else batch + ) + return self.__class__( # pytype: disable=wrong-keyword-args + _count=np.sum(~np.isnan(batch), axis=0), + _mean=np.nanmean(batch, axis=0), + _var=np.nanvar(batch, axis=0), + _input_shape=batch.shape if batch.size else (), + ) + + @property + def var(self) -> types.NumbersT: + return self._var + + @property + def stddev(self) -> types.NumbersT: + return np.sqrt(self._var) + + def merge(self, other: Self) -> None: + if np.all(np.isnan(other.var)): + return + prev_mean, prev_count = np.copy(self._mean), np.copy(self._count) + super().merge(other) + if np.all(np.isnan(self._var)): + self._var = other.var + return + # Reference + # (https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups) + prev_count_ratio = math_utils.safe_divide(prev_count, self._count) + other_count_ratio = math_utils.safe_divide(other.count, self._count) + delta_mean = math_utils.nanadd(self._mean, -prev_mean) + mean_diff = math_utils.nanadd(other.mean, -self._mean) + self._var = ( + prev_count_ratio * self._var + + other_count_ratio * other.var + + prev_count_ratio * delta_mean**2 + + other_count_ratio * mean_diff**2 + ) + + def result(self) -> types.NumbersT: + return self + + def __str__(self): + return ( + f'count: {self.count}, total: {self.total}, mean: {self.mean}, ' + f'var: {self.var}, stddev: {self.stddev}' + ) + + +@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS) +class Var(MeanAndVariance): + + def result(self) -> types.NumbersT: + return self.var + + def __str__(self): + return f'var: {self.var}' + + +@dataclasses.dataclass(slots=True) +class NumericStats: + """Statistics for a numeric feature.""" + + min: float = np.inf + max: float = -np.inf + num_zeros: int = 0 + mean: float = np.nan + std_dev: float = np.nan + mva: MeanAndVariance = dataclasses.field(default_factory=MeanAndVariance) + + def update(self, value: list[Any]): + self.mva.add(value) + self.min = np.minimum(self.min, np.min(value)) + self.max = np.maximum(self.max, np.max(value)) + self.num_zeros += np.sum(np.equal(value, 0)) + + def merge(self, other: Self): + if other.mva.count > 0: + self.mva.merge(other.mva) + self.min = min(self.min, other.min) + self.max = max(self.max, other.max) + self.num_zeros += other.num_zeros + + def compute_result(self): + """Computes the final statistics and returns a new NumericStats.""" + mva_res = self.mva.result() + self.mean = mva_res.mean + self.std_dev = mva_res.stddev + + @dataclasses.dataclass(slots=True) class FeatureStats: """Statistics for a single feature.""" @@ -384,18 +561,24 @@ class FeatureStats: min_num_values: int | None = None tot_num_values: int = 0 avg_num_values: float = 0.0 - - def update(self, length: int, feature_type: FeatureType | None): - self.merge( - FeatureStats( - feature_type=feature_type, - num_non_missing=1, - max_num_values=length, - min_num_values=length, - tot_num_values=length, - avg_num_values=float(length), - ) + numeric_stats: NumericStats | None = None + + def update(self, value: list[Any]): + """Updates the feature stats with a new value.""" + feature_type = _get_feature_type(value) + length = len(value) + fs = FeatureStats( + feature_type=feature_type, + num_non_missing=1, + max_num_values=length, + min_num_values=length, + tot_num_values=length, + avg_num_values=float(length), ) + if feature_type in [FeatureType.INT, FeatureType.FLOAT] and value: + fs.numeric_stats = NumericStats() + fs.numeric_stats.update(value) + self.merge(fs) def merge(self, other: Self): """Merges with other feature stats.""" @@ -421,6 +604,10 @@ def merge(self, other: Self): if self.num_non_missing else 0.0 ) + if other.numeric_stats is not None: + if self.numeric_stats is None: + self.numeric_stats = NumericStats() + self.numeric_stats.merge(other.numeric_stats) @dataclasses.dataclass(slots=True) @@ -509,7 +696,7 @@ def new( num_examples += 1 for key, value in example.items(): try: - feature_stats[key].update(len(value), _get_feature_type(value)) + feature_stats[key].update(value) except ValueError as e: if 'conflicting types' in str(e): raise ValueError(f'Feature {key} has conflicting types: {e}') from e @@ -538,155 +725,11 @@ def merge(self, other: Self) -> None: def result(self) -> TfExampleStats: for feature in self._stats.feature_stats.values(): feature.num_missing = self._stats.num_examples - feature.num_non_missing + if feature.numeric_stats is not None: + feature.numeric_stats.compute_result() return self._stats -@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS) -@dataclasses.dataclass(kw_only=True, eq=True) -class Mean(chainable.CallableMetric): - """Computes the mean and variance of a batch of values.""" - - batch_score_fn: Callable[..., types.NumbersT] | None = None - _count: types.NumbersT = 0 - _mean: types.NumbersT = np.nan - _input_shape: tuple[int, ...] = () - - def as_agg_fn(self, *, nested: bool = False) -> chainable.AggregateFn: - return chainable.as_agg_fn( - self.__class__, - batch_score_fn=self.batch_score_fn if not nested else None, - nested=nested, - agg_preprocess_fn=self.batch_score_fn if nested else None, - ) - - def new(self, batch: types.NumbersT) -> types.NumbersT: - """Returns a new instance with the sufficient statistics reset and updated. - - If `batch_score_fn` is provided, it will evaluate the batch and assign a - score to each item. Subsequently, the statistics are computed based on - non-nan values within the batch. If a certain dimension in batch is all nan, - mean and variance corresponding to that dimension will be nan, count for - that dimension will be 0. - - Args: - batch: A non-vacant series of values. - - Returns: - Mean - """ - batch = np.asarray( - self.batch_score_fn(batch) if self.batch_score_fn else batch - ) - return self.__class__( - _count=np.sum(~np.isnan(batch), axis=0), - _mean=np.nanmean(batch, axis=0), - _input_shape=batch.shape if batch.size else (), - ) - - @property - def count(self) -> types.NumbersT: - return self._count - - @property - def mean(self) -> types.NumbersT: - return self._mean - - @property - def total(self) -> types.NumbersT: - return math_utils.where(self._count > 0, self._mean * self._count, 0) - - @property - def input_shape(self) -> tuple[int, ...]: - return self._input_shape - - def merge(self, other: Self) -> None: - if np.all(np.isnan(other.mean)): - return - self._input_shape = self._input_shape or other.input_shape - if other.input_shape and other.input_shape[1:] != self._input_shape[1:]: - raise ValueError( - f'Incompatible shape {other.input_shape} while the' - f' other have shape {self._input_shape}.' - ) - self._count += other.count - mean_diff = math_utils.nanadd(other.mean, -self._mean) - update = mean_diff * math_utils.safe_divide(other.count, self._count) - self._mean = math_utils.nanadd(self._mean, update) - - def result(self) -> Self: - return self.mean - - def __str__(self): - return f'mean: {self.mean}' - - -@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS) -@dataclasses.dataclass(kw_only=True, eq=True) -class MeanAndVariance(Mean): - """Computes the mean and variance of a batch of values.""" - - _var: types.NumbersT = np.nan - - def new(self, batch: types.NumbersT) -> types.NumbersT: - batch = np.asarray( - self.batch_score_fn(batch) if self.batch_score_fn else batch - ) - return self.__class__( # pytype: disable=wrong-keyword-args - _count=np.sum(~np.isnan(batch), axis=0), - _mean=np.nanmean(batch, axis=0), - _var=np.nanvar(batch, axis=0), - _input_shape=batch.shape if batch.size else (), - ) - - @property - def var(self) -> types.NumbersT: - return self._var - - @property - def stddev(self) -> types.NumbersT: - return np.sqrt(self._var) - - def merge(self, other: Self) -> None: - if np.all(np.isnan(other.var)): - return - prev_mean, prev_count = np.copy(self._mean), np.copy(self._count) - super().merge(other) - if np.all(np.isnan(self._var)): - self._var = other.var - return - # Reference - # (https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups) - prev_count_ratio = math_utils.safe_divide(prev_count, self._count) - other_count_ratio = math_utils.safe_divide(other.count, self._count) - delta_mean = math_utils.nanadd(self._mean, -prev_mean) - mean_diff = math_utils.nanadd(other.mean, -self._mean) - self._var = ( - prev_count_ratio * self._var - + other_count_ratio * other.var - + prev_count_ratio * delta_mean**2 - + other_count_ratio * mean_diff**2 - ) - - def result(self) -> types.NumbersT: - return self - - def __str__(self): - return ( - f'count: {self.count}, total: {self.total}, mean: {self.mean}, ' - f'var: {self.var}, stddev: {self.stddev}' - ) - - -@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS) -class Var(MeanAndVariance): - - def result(self) -> types.NumbersT: - return self.var - - def __str__(self): - return f'var: {self.var}' - - @telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS) # TODO(b/345249574): Add a preprocessing function of len per row. @dataclasses.dataclass(slots=True) diff --git a/ml_metrics/_src/aggregates/stats_test.py b/ml_metrics/_src/aggregates/stats_test.py index 1c9473e8..d48fd957 100644 --- a/ml_metrics/_src/aggregates/stats_test.py +++ b/ml_metrics/_src/aggregates/stats_test.py @@ -686,67 +686,132 @@ def test_avg_num_values_zero_non_missing(self): class TfExampleStatsAggTest(parameterized.TestCase): + def assert_feature_stats_equal(self, fs1, fs2, places=6): + self.assertEqual(fs1.feature_type, fs2.feature_type) + self.assertEqual(fs1.num_missing, fs2.num_missing) + self.assertEqual(fs1.num_non_missing, fs2.num_non_missing) + self.assertEqual(fs1.max_num_values, fs2.max_num_values) + self.assertEqual(fs1.min_num_values, fs2.min_num_values) + self.assertEqual(fs1.tot_num_values, fs2.tot_num_values) + self.assertAlmostEqual( + fs1.avg_num_values, fs2.avg_num_values, places=places + ) + if fs1.numeric_stats is None: + self.assertIsNone(fs2.numeric_stats) + else: + self.assertIsNotNone(fs2.numeric_stats) + self.assertEqual(fs1.numeric_stats.num_zeros, fs2.numeric_stats.num_zeros) + self.assertAlmostEqual( + fs1.numeric_stats.min, fs2.numeric_stats.min, places=places + ) + self.assertAlmostEqual( + fs1.numeric_stats.max, fs2.numeric_stats.max, places=places + ) + self.assertAlmostEqual( + fs1.numeric_stats.mean, fs2.numeric_stats.mean, places=places + ) + self.assertAlmostEqual( + fs1.numeric_stats.std_dev, fs2.numeric_stats.std_dev, places=places + ) + + def assert_tf_example_stats_equal(self, expected, actual, places=6): + self.assertEqual(expected.num_examples, actual.num_examples) + self.assertCountEqual( + expected.feature_stats.keys(), actual.feature_stats.keys() + ) + for key in expected.feature_stats: + with self.subTest(key): + self.assert_feature_stats_equal( + expected.feature_stats[key], + actual.feature_stats[key], + places=places, + ) + def test_single_example(self): examples = [{'a': [1], 'b': [1, 2]}] agg = stats.TfExampleStatsAgg() agg.add(examples) - self.assertEqual( - stats.TfExampleStats( - num_examples=1, - feature_stats={ - 'a': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=0, - num_non_missing=1, - max_num_values=1, - min_num_values=1, - tot_num_values=1, - avg_num_values=1.0, + expected = stats.TfExampleStats( + num_examples=1, + feature_stats={ + 'a': stats.FeatureStats( + feature_type=stats.FeatureType.INT, + num_missing=0, + num_non_missing=1, + max_num_values=1, + min_num_values=1, + tot_num_values=1, + avg_num_values=1.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=1.0, + mean=1.0, + std_dev=0.0, + num_zeros=0, ), - 'b': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=0, - num_non_missing=1, - max_num_values=2, - min_num_values=2, - tot_num_values=2, - avg_num_values=2.0, + ), + 'b': stats.FeatureStats( + feature_type=stats.FeatureType.INT, + num_missing=0, + num_non_missing=1, + max_num_values=2, + min_num_values=2, + tot_num_values=2, + avg_num_values=2.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=2.0, + mean=1.5, + std_dev=0.5, + num_zeros=0, ), - }, - ), - agg.result(), + ), + }, ) + self.assert_tf_example_stats_equal(expected, agg.result()) def test_multiple_examples(self): examples = [{'a': [1], 'b': [1, 2]}, {'a': [1, 2, 3]}] agg = stats.TfExampleStatsAgg() agg.add(examples) - self.assertEqual( - stats.TfExampleStats( - num_examples=2, - feature_stats={ - 'a': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=0, - num_non_missing=2, - max_num_values=3, - min_num_values=1, - tot_num_values=4, - avg_num_values=2.0, + expected = stats.TfExampleStats( + num_examples=2, + feature_stats={ + 'a': stats.FeatureStats( + feature_type=stats.FeatureType.INT, + num_missing=0, + num_non_missing=2, + max_num_values=3, + min_num_values=1, + tot_num_values=4, + avg_num_values=2.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=3.0, + mean=1.75, + std_dev=0.8291561975800501, + num_zeros=0, ), - 'b': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=1, - num_non_missing=1, - max_num_values=2, - min_num_values=2, - tot_num_values=2, - avg_num_values=2.0, + ), + 'b': stats.FeatureStats( + feature_type=stats.FeatureType.INT, + num_missing=1, + num_non_missing=1, + max_num_values=2, + min_num_values=2, + tot_num_values=2, + avg_num_values=2.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=2.0, + mean=1.5, + std_dev=0.5, + num_zeros=0, ), - }, - ), - agg.result(), + ), + }, ) + self.assert_tf_example_stats_equal(expected, agg.result()) def test_merge(self): examples1 = [{'a': [1], 'b': [1, 2]}, {'a': [1, 2, 3]}] @@ -756,111 +821,158 @@ def test_merge(self): agg2 = stats.TfExampleStatsAgg() agg2.add(examples2) agg1.merge(agg2) - self.assertEqual( - stats.TfExampleStats( - num_examples=3, - feature_stats={ - 'a': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=1, - num_non_missing=2, - max_num_values=3, - min_num_values=1, - tot_num_values=4, - avg_num_values=2.0, + expected = stats.TfExampleStats( + num_examples=3, + feature_stats={ + 'a': stats.FeatureStats( + feature_type=stats.FeatureType.INT, + num_missing=1, + num_non_missing=2, + max_num_values=3, + min_num_values=1, + tot_num_values=4, + avg_num_values=2.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=3.0, + mean=1.75, + std_dev=0.8291561975800501, + num_zeros=0, ), - 'b': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=1, - num_non_missing=2, - max_num_values=4, - min_num_values=2, - tot_num_values=6, - avg_num_values=3.0, + ), + 'b': stats.FeatureStats( + feature_type=stats.FeatureType.INT, + num_missing=1, + num_non_missing=2, + max_num_values=4, + min_num_values=2, + tot_num_values=6, + avg_num_values=3.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=4.0, + mean=2.1666666666666665, + std_dev=1.067187372604722, + num_zeros=0, ), - 'c': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=2, - num_non_missing=1, - max_num_values=1, - min_num_values=1, - tot_num_values=1, - avg_num_values=1.0, + ), + 'c': stats.FeatureStats( + feature_type=stats.FeatureType.INT, + num_missing=2, + num_non_missing=1, + max_num_values=1, + min_num_values=1, + tot_num_values=1, + avg_num_values=1.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=1.0, + mean=1.0, + std_dev=0.0, + num_zeros=0, ), - }, - ), - agg1.result(), + ), + }, ) + self.assert_tf_example_stats_equal(expected, agg1.result()) def test_unbatched(self): examples = [{'a': [1], 'b': [1, 2]}, {'a': [1, 2, 3]}] agg = stats.TfExampleStatsAgg(batched_inputs=False) for example in examples: agg.add(example) - self.assertEqual( - stats.TfExampleStats( - num_examples=2, - feature_stats={ - 'a': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=0, - num_non_missing=2, - max_num_values=3, - min_num_values=1, - tot_num_values=4, - avg_num_values=2.0, + expected = stats.TfExampleStats( + num_examples=2, + feature_stats={ + 'a': stats.FeatureStats( + feature_type=stats.FeatureType.INT, + num_missing=0, + num_non_missing=2, + max_num_values=3, + min_num_values=1, + tot_num_values=4, + avg_num_values=2.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=3.0, + mean=1.75, + std_dev=0.8291561975800501, + num_zeros=0, ), - 'b': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=1, - num_non_missing=1, - max_num_values=2, - min_num_values=2, - tot_num_values=2, - avg_num_values=2.0, + ), + 'b': stats.FeatureStats( + feature_type=stats.FeatureType.INT, + num_missing=1, + num_non_missing=1, + max_num_values=2, + min_num_values=2, + tot_num_values=2, + avg_num_values=2.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=2.0, + mean=1.5, + std_dev=0.5, + num_zeros=0, ), - }, - ), - agg.result(), + ), + }, ) + self.assert_tf_example_stats_equal(expected, agg.result()) def test_feature_types(self): examples = [{'a': [1], 'b': [1.0, 2.0], 'c': ['foo', 'bar']}] agg = stats.TfExampleStatsAgg() agg.add(examples) - self.assertEqual( - stats.TfExampleStats( - num_examples=1, - feature_stats={ - 'a': stats.FeatureStats( - feature_type=stats.FeatureType.INT, - num_missing=0, - num_non_missing=1, - max_num_values=1, - min_num_values=1, - tot_num_values=1, - avg_num_values=1.0, - ), - 'b': stats.FeatureStats( - feature_type=stats.FeatureType.FLOAT, - num_missing=0, - num_non_missing=1, - max_num_values=2, - min_num_values=2, - tot_num_values=2, - avg_num_values=2.0, + expected = stats.TfExampleStats( + num_examples=1, + feature_stats={ + 'a': stats.FeatureStats( + feature_type=stats.FeatureType.INT, + num_missing=0, + num_non_missing=1, + max_num_values=1, + min_num_values=1, + tot_num_values=1, + avg_num_values=1.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=1.0, + mean=1.0, + std_dev=0.0, + num_zeros=0, ), - 'c': stats.FeatureStats( - feature_type=stats.FeatureType.STRING, - num_missing=0, - num_non_missing=1, - max_num_values=2, - min_num_values=2, - tot_num_values=2, - avg_num_values=2.0, + ), + 'b': stats.FeatureStats( + feature_type=stats.FeatureType.FLOAT, + num_missing=0, + num_non_missing=1, + max_num_values=2, + min_num_values=2, + tot_num_values=2, + avg_num_values=2.0, + numeric_stats=stats.NumericStats( + min=1.0, + max=2.0, + mean=1.5, + std_dev=0.5, + num_zeros=0, ), - }, - ), + ), + 'c': stats.FeatureStats( + feature_type=stats.FeatureType.STRING, + num_missing=0, + num_non_missing=1, + max_num_values=2, + min_num_values=2, + tot_num_values=2, + avg_num_values=2.0, + numeric_stats=None, + ), + }, + ) + self.assert_tf_example_stats_equal( + expected, agg.result(), ) @@ -1976,8 +2088,8 @@ class TfdvTest(absltest.TestCase): def test_to_proto(self): feature_stats_instance = stats.FeatureStats() - feature_stats_instance.update(1, stats.FeatureType.INT) - feature_stats_instance.update(2, stats.FeatureType.INT) + feature_stats_instance.update([1]) + feature_stats_instance.update([1, 2]) data = stats.TfExampleStats( num_examples=2, feature_stats={