246 changes: 0 additions & 246 deletions ml_metrics/_src/aggregates/stats.py
@@ -19,7 +19,6 @@
import collections
from collections.abc import Callable, Iterable
import dataclasses
import enum
import itertools
import math
from typing import Any, Generic, Self, TypeVar
@@ -30,44 +29,10 @@
from ml_metrics._src.tools.telemetry import telemetry
import numpy as np

from tensorflow_metadata.proto.v0 import statistics_pb2

_EPSNEG = np.finfo(float).epsneg
_T = TypeVar('_T')


@enum.unique
class FeatureType(enum.Enum):
"""Feature types."""

INT = 'int'
FLOAT = 'float'
STRING = 'string'


def _get_feature_type(value: list[Any]) -> FeatureType | None:
"""Returns the feature type of the given value."""
if not value:
return None
if isinstance(value[0], str):
return FeatureType.STRING
elif isinstance(value[0], bytes):
try:
value[0].decode('utf-8')
return FeatureType.STRING
except UnicodeDecodeError as e:
raise ValueError(
'Unsupported bytes feature type. Feature could not be decoded as'
f' UTF-8 string: {e}'
) from e
elif isinstance(value[0], int):
return FeatureType.INT
elif isinstance(value[0], float):
return FeatureType.FLOAT
else:
raise ValueError(f'Unsupported feature type: {type(value[0])}')
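
A minimal usage sketch (not from the original file), assuming the removed FeatureType and _get_feature_type were still importable from ml_metrics._src.aggregates.stats; the sample values are arbitrary:

from ml_metrics._src.aggregates import stats

# The type is inferred from the first element; UTF-8-decodable bytes count as STRING.
assert stats._get_feature_type([1, 2, 3]) is stats.FeatureType.INT
assert stats._get_feature_type([0.5]) is stats.FeatureType.FLOAT
assert stats._get_feature_type([b'abc']) is stats.FeatureType.STRING
assert stats._get_feature_type([]) is None  # an empty value list carries no type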


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
@dataclasses.dataclass(kw_only=True)
class UnboundedSampler(chainable.CallableMetric, chainable.HasAsAggFn):
Expand Down Expand Up @@ -519,217 +484,6 @@ def __str__(self):
return f'var: {self.var}'


@dataclasses.dataclass(slots=True)
class NumericStats:
"""Statistics for a numeric feature."""

min: float = np.inf
max: float = -np.inf
num_zeros: int = 0
mean: float = np.nan
std_dev: float = np.nan
mva: MeanAndVariance = dataclasses.field(default_factory=MeanAndVariance)

def update(self, value: list[Any]):
self.mva.add(value)
self.min = np.minimum(self.min, np.min(value))
self.max = np.maximum(self.max, np.max(value))
self.num_zeros += np.sum(np.equal(value, 0))

def merge(self, other: Self):
if other.mva.count > 0:
self.mva.merge(other.mva)
self.min = min(self.min, other.min)
self.max = max(self.max, other.max)
self.num_zeros += other.num_zeros

def compute_result(self):
"""Computes the final statistics and returns a new NumericStats."""
mva_res = self.mva.result()
self.mean = mva_res.mean
self.std_dev = mva_res.stddev
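
A short sketch of the removed NumericStats lifecycle, assuming the class were still importable: update() with one or more batches, merge() partial accumulators, then compute_result() to fill mean and std_dev in place. The values below are arbitrary:

from ml_metrics._src.aggregates import stats

a = stats.NumericStats()
a.update([1.0, 2.0, 0.0])
b = stats.NumericStats()
b.update([4.0])
a.merge(b)          # combine partial accumulators
a.compute_result()  # finalize mean/std_dev from the MeanAndVariance state
print(a.min, a.max, a.num_zeros, a.mean)  # 0.0 4.0 1 1.75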


@dataclasses.dataclass(slots=True)
class FeatureStats:
"""Statistics for a single feature."""

feature_type: FeatureType | None = None
num_missing: int = 0
num_non_missing: int = 0
max_num_values: int = 0
min_num_values: int | None = None
tot_num_values: int = 0
avg_num_values: float = 0.0
numeric_stats: NumericStats | None = None

def update(self, value: list[Any]):
"""Updates the feature stats with a new value."""
feature_type = _get_feature_type(value)
length = len(value)
fs = FeatureStats(
feature_type=feature_type,
num_non_missing=1,
max_num_values=length,
min_num_values=length,
tot_num_values=length,
avg_num_values=float(length),
)
if feature_type in [FeatureType.INT, FeatureType.FLOAT] and value:
fs.numeric_stats = NumericStats()
fs.numeric_stats.update(value)
self.merge(fs)

def merge(self, other: Self):
"""Merges with other feature stats."""
if other.feature_type is not None:
if self.feature_type is None:
self.feature_type = other.feature_type
elif self.feature_type != other.feature_type:
raise ValueError(
f'Feature has conflicting types: {self.feature_type} vs'
f' {other.feature_type}'
)
self.num_non_missing += other.num_non_missing
self.max_num_values = max(self.max_num_values, other.max_num_values)
if other.min_num_values is None:
pass
elif self.min_num_values is None:
self.min_num_values = other.min_num_values
else:
self.min_num_values = min(self.min_num_values, other.min_num_values)
self.tot_num_values += other.tot_num_values
self.avg_num_values = (
self.tot_num_values / self.num_non_missing
if self.num_non_missing
else 0.0
)
if other.numeric_stats is not None:
if self.numeric_stats is None:
self.numeric_stats = NumericStats()
self.numeric_stats.merge(other.numeric_stats)
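
A short sketch of the removed FeatureStats accumulator, assuming it were still importable; each update() records one non-missing occurrence of the feature, and mixing value types raises ValueError. The values are arbitrary:

from ml_metrics._src.aggregates import stats

fs = stats.FeatureStats()
fs.update([1, 2, 3])
fs.update([5])
# num_non_missing=2, min_num_values=1, max_num_values=3, avg_num_values=2.0
print(fs.num_non_missing, fs.min_num_values, fs.max_num_values, fs.avg_num_values)
# fs.update(['text'])  # would raise: feature has conflicting types INT vs STRING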


@dataclasses.dataclass(slots=True)
class TfExampleStats:
"""The result of TfExampleStatsAgg."""

num_examples: int = 0
feature_stats: dict[str, FeatureStats] = dataclasses.field(
default_factory=dict
)

def to_proto(self):
"""Writes the data to the sink."""
feature_stats_proto = statistics_pb2.DatasetFeatureStatistics(
num_examples=self.num_examples
)
for name, feature_stats in self.feature_stats.items():
feature_name_stats = feature_stats_proto.features.add(
path={'step': [name]}
)
feature_name_stats.num_stats.common_stats.num_missing = (
feature_stats.num_missing
)
feature_name_stats.num_stats.common_stats.num_non_missing = (
feature_stats.num_non_missing
)
feature_name_stats.num_stats.common_stats.min_num_values = (
feature_stats.min_num_values
)
feature_name_stats.num_stats.common_stats.max_num_values = (
feature_stats.max_num_values
)
feature_name_stats.num_stats.common_stats.avg_num_values = (
feature_stats.avg_num_values
)
feature_name_stats.num_stats.common_stats.tot_num_values = (
feature_stats.tot_num_values
)
if feature_stats.feature_type:
feature_name_stats.type = (
statistics_pb2.FeatureNameStatistics.Type.Value(
feature_stats.feature_type.name
)
)
return statistics_pb2.DatasetFeatureStatisticsList(
datasets=[feature_stats_proto]
)
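
A short sketch of exporting the removed TfExampleStats to the TFMD proto, assuming the classes were still importable; the feature name 'score' and the counts are arbitrary:

from ml_metrics._src.aggregates import stats

fs = stats.FeatureStats()
fs.update([1.0, 2.0])
ts = stats.TfExampleStats(num_examples=1, feature_stats={'score': fs})
proto = ts.to_proto()  # a statistics_pb2.DatasetFeatureStatisticsList
print(proto.datasets[0].num_examples)           # 1
print(proto.datasets[0].features[0].path.step)  # ['score']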


@telemetry.class_monitor(category=telemetry.CATEGORY.STATS)
@dataclasses.dataclass(slots=True, kw_only=True)
class TfExampleStatsAgg(chainable.CallableMetric):
"""Computes statistics on features."""

batched_inputs: bool = True
_stats: TfExampleStats = dataclasses.field(default_factory=TfExampleStats)

def as_agg_fn(self) -> chainable.AggregateFn:
return chainable.as_agg_fn(
self.__class__,
batched_inputs=self.batched_inputs,
)

@property
def num_examples(self) -> int:
return self._stats.num_examples

@property
def feature_stats(self) -> dict[str, FeatureStats]:
return self._stats.feature_stats

def new(
self, inputs: dict[str, list[Any]] | list[dict[str, list[Any]]]
) -> Self:
"""Computes the sufficient statistics of a batch of inputs."""

    # Wrap a single dict into a list (e.g. when batched_inputs is False).
if isinstance(inputs, dict):
inputs_list = [inputs]
else:
inputs_list = inputs

num_examples = 0
feature_stats = collections.defaultdict(FeatureStats)
for example in inputs_list:
num_examples += 1
for key, value in example.items():
try:
feature_stats[key].update(value)
except ValueError as e:
if 'conflicting types' in str(e):
raise ValueError(f'Feature {key} has conflicting types: {e}') from e
raise
return self.__class__(
batched_inputs=self.batched_inputs,
_stats=TfExampleStats(
num_examples=num_examples,
feature_stats=dict(feature_stats),
),
)

def merge(self, other: Self) -> None:
self._stats.num_examples += other.num_examples
for key, value in other.feature_stats.items():
if key in self._stats.feature_stats:
try:
self._stats.feature_stats[key].merge(value)
except ValueError as e:
if 'conflicting types' in str(e):
raise ValueError(f'Feature {key} has conflicting types: {e}') from e
raise
else:
self._stats.feature_stats[key] = value

def result(self) -> TfExampleStats:
for feature in self._stats.feature_stats.values():
feature.num_missing = self._stats.num_examples - feature.num_non_missing
if feature.numeric_stats is not None:
feature.numeric_stats.compute_result()
return self._stats
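
An end-to-end sketch of the removed aggregator, assuming it were still importable: new() builds sufficient statistics per batch, merge() combines partial results, and result() finalizes num_missing and the numeric stats. The feature names 'a' and 'b' and their values are arbitrary:

from ml_metrics._src.aggregates import stats

agg = stats.TfExampleStatsAgg()
part1 = agg.new([{'a': [1, 2]}, {'a': [3], 'b': [0.5]}])
part2 = agg.new([{'b': [1.5, 2.5]}])
part1.merge(part2)
res = part1.result()  # a TfExampleStats
print(res.num_examples)                    # 3
print(res.feature_stats['a'].num_missing)  # 1: 'a' is absent from one example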


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
# TODO(b/345249574): Add a preprocessing function that computes len per row.
@dataclasses.dataclass(slots=True)