Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
359 changes: 201 additions & 158 deletions ml_metrics/_src/aggregates/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,183 @@ def __str__(self):
return f'count: {self.result()}'


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
@dataclasses.dataclass(kw_only=True, eq=True)
class Mean(chainable.CallableMetric):
"""Computes the mean and variance of a batch of values."""

batch_score_fn: Callable[..., types.NumbersT] | None = None
_count: types.NumbersT = 0
_mean: types.NumbersT = np.nan
_input_shape: tuple[int, ...] = ()

def as_agg_fn(self, *, nested: bool = False) -> chainable.AggregateFn:
return chainable.as_agg_fn(
self.__class__,
batch_score_fn=self.batch_score_fn if not nested else None,
nested=nested,
agg_preprocess_fn=self.batch_score_fn if nested else None,
)

def new(self, batch: types.NumbersT) -> types.NumbersT:
"""Returns a new instance with the sufficient statistics reset and updated.

If `batch_score_fn` is provided, it will evaluate the batch and assign a
score to each item. Subsequently, the statistics are computed based on
non-nan values within the batch. If a certain dimension in batch is all nan,
mean and variance corresponding to that dimension will be nan, count for
that dimension will be 0.

Args:
batch: A non-vacant series of values.

Returns:
Mean
"""
batch = np.asarray(
self.batch_score_fn(batch) if self.batch_score_fn else batch
)
return self.__class__(
_count=np.sum(~np.isnan(batch), axis=0),
_mean=np.nanmean(batch, axis=0),
_input_shape=batch.shape if batch.size else (),
)

@property
def count(self) -> types.NumbersT:
return self._count

@property
def mean(self) -> types.NumbersT:
return self._mean

@property
def total(self) -> types.NumbersT:
return math_utils.where(self._count > 0, self._mean * self._count, 0)

@property
def input_shape(self) -> tuple[int, ...]:
return self._input_shape

def merge(self, other: Self) -> None:
if np.all(np.isnan(other.mean)):
return
self._input_shape = self._input_shape or other.input_shape
if other.input_shape and other.input_shape[1:] != self._input_shape[1:]:
raise ValueError(
f'Incompatible shape {other.input_shape} while the'
f' other have shape {self._input_shape}.'
)
self._count += other.count
mean_diff = math_utils.nanadd(other.mean, -self._mean)
update = mean_diff * math_utils.safe_divide(other.count, self._count)
self._mean = math_utils.nanadd(self._mean, update)

def result(self) -> Self:
return self.mean

def __str__(self):
return f'mean: {self.mean}'


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
@dataclasses.dataclass(kw_only=True, eq=True)
class MeanAndVariance(Mean):
"""Computes the mean and variance of a batch of values."""

_var: types.NumbersT = np.nan

def new(self, batch: types.NumbersT) -> types.NumbersT:
batch = np.asarray(
self.batch_score_fn(batch) if self.batch_score_fn else batch
)
return self.__class__( # pytype: disable=wrong-keyword-args
_count=np.sum(~np.isnan(batch), axis=0),
_mean=np.nanmean(batch, axis=0),
_var=np.nanvar(batch, axis=0),
_input_shape=batch.shape if batch.size else (),
)

@property
def var(self) -> types.NumbersT:
return self._var

@property
def stddev(self) -> types.NumbersT:
return np.sqrt(self._var)

def merge(self, other: Self) -> None:
if np.all(np.isnan(other.var)):
return
prev_mean, prev_count = np.copy(self._mean), np.copy(self._count)
super().merge(other)
if np.all(np.isnan(self._var)):
self._var = other.var
return
# Reference
# (https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups)
prev_count_ratio = math_utils.safe_divide(prev_count, self._count)
other_count_ratio = math_utils.safe_divide(other.count, self._count)
delta_mean = math_utils.nanadd(self._mean, -prev_mean)
mean_diff = math_utils.nanadd(other.mean, -self._mean)
self._var = (
prev_count_ratio * self._var
+ other_count_ratio * other.var
+ prev_count_ratio * delta_mean**2
+ other_count_ratio * mean_diff**2
)

def result(self) -> types.NumbersT:
return self

def __str__(self):
return (
f'count: {self.count}, total: {self.total}, mean: {self.mean}, '
f'var: {self.var}, stddev: {self.stddev}'
)


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
class Var(MeanAndVariance):

def result(self) -> types.NumbersT:
return self.var

def __str__(self):
return f'var: {self.var}'


@dataclasses.dataclass(slots=True)
class NumericStats:
"""Statistics for a numeric feature."""

min: float = np.inf
max: float = -np.inf
num_zeros: int = 0
mean: float = np.nan
std_dev: float = np.nan
mva: MeanAndVariance = dataclasses.field(default_factory=MeanAndVariance)

def update(self, value: list[Any]):
self.mva.add(value)
self.min = np.minimum(self.min, np.min(value))
self.max = np.maximum(self.max, np.max(value))
self.num_zeros += np.sum(np.equal(value, 0))

def merge(self, other: Self):
if other.mva.count > 0:
self.mva.merge(other.mva)
self.min = min(self.min, other.min)
self.max = max(self.max, other.max)
self.num_zeros += other.num_zeros

def compute_result(self):
"""Computes the final statistics and returns a new NumericStats."""
mva_res = self.mva.result()
self.mean = mva_res.mean
self.std_dev = mva_res.stddev


@dataclasses.dataclass(slots=True)
class FeatureStats:
"""Statistics for a single feature."""
Expand All @@ -384,18 +561,24 @@ class FeatureStats:
min_num_values: int | None = None
tot_num_values: int = 0
avg_num_values: float = 0.0

def update(self, length: int, feature_type: FeatureType | None):
self.merge(
FeatureStats(
feature_type=feature_type,
num_non_missing=1,
max_num_values=length,
min_num_values=length,
tot_num_values=length,
avg_num_values=float(length),
)
numeric_stats: NumericStats | None = None

def update(self, value: list[Any]):
"""Updates the feature stats with a new value."""
feature_type = _get_feature_type(value)
length = len(value)
fs = FeatureStats(
feature_type=feature_type,
num_non_missing=1,
max_num_values=length,
min_num_values=length,
tot_num_values=length,
avg_num_values=float(length),
)
if feature_type in [FeatureType.INT, FeatureType.FLOAT] and value:
fs.numeric_stats = NumericStats()
fs.numeric_stats.update(value)
self.merge(fs)

def merge(self, other: Self):
"""Merges with other feature stats."""
Expand All @@ -421,6 +604,10 @@ def merge(self, other: Self):
if self.num_non_missing
else 0.0
)
if other.numeric_stats is not None:
if self.numeric_stats is None:
self.numeric_stats = NumericStats()
self.numeric_stats.merge(other.numeric_stats)


@dataclasses.dataclass(slots=True)
Expand Down Expand Up @@ -509,7 +696,7 @@ def new(
num_examples += 1
for key, value in example.items():
try:
feature_stats[key].update(len(value), _get_feature_type(value))
feature_stats[key].update(value)
except ValueError as e:
if 'conflicting types' in str(e):
raise ValueError(f'Feature {key} has conflicting types: {e}') from e
Expand Down Expand Up @@ -538,155 +725,11 @@ def merge(self, other: Self) -> None:
def result(self) -> TfExampleStats:
for feature in self._stats.feature_stats.values():
feature.num_missing = self._stats.num_examples - feature.num_non_missing
if feature.numeric_stats is not None:
feature.numeric_stats.compute_result()
return self._stats


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
@dataclasses.dataclass(kw_only=True, eq=True)
class Mean(chainable.CallableMetric):
"""Computes the mean and variance of a batch of values."""

batch_score_fn: Callable[..., types.NumbersT] | None = None
_count: types.NumbersT = 0
_mean: types.NumbersT = np.nan
_input_shape: tuple[int, ...] = ()

def as_agg_fn(self, *, nested: bool = False) -> chainable.AggregateFn:
return chainable.as_agg_fn(
self.__class__,
batch_score_fn=self.batch_score_fn if not nested else None,
nested=nested,
agg_preprocess_fn=self.batch_score_fn if nested else None,
)

def new(self, batch: types.NumbersT) -> types.NumbersT:
"""Returns a new instance with the sufficient statistics reset and updated.

If `batch_score_fn` is provided, it will evaluate the batch and assign a
score to each item. Subsequently, the statistics are computed based on
non-nan values within the batch. If a certain dimension in batch is all nan,
mean and variance corresponding to that dimension will be nan, count for
that dimension will be 0.

Args:
batch: A non-vacant series of values.

Returns:
Mean
"""
batch = np.asarray(
self.batch_score_fn(batch) if self.batch_score_fn else batch
)
return self.__class__(
_count=np.sum(~np.isnan(batch), axis=0),
_mean=np.nanmean(batch, axis=0),
_input_shape=batch.shape if batch.size else (),
)

@property
def count(self) -> types.NumbersT:
return self._count

@property
def mean(self) -> types.NumbersT:
return self._mean

@property
def total(self) -> types.NumbersT:
return math_utils.where(self._count > 0, self._mean * self._count, 0)

@property
def input_shape(self) -> tuple[int, ...]:
return self._input_shape

def merge(self, other: Self) -> None:
if np.all(np.isnan(other.mean)):
return
self._input_shape = self._input_shape or other.input_shape
if other.input_shape and other.input_shape[1:] != self._input_shape[1:]:
raise ValueError(
f'Incompatible shape {other.input_shape} while the'
f' other have shape {self._input_shape}.'
)
self._count += other.count
mean_diff = math_utils.nanadd(other.mean, -self._mean)
update = mean_diff * math_utils.safe_divide(other.count, self._count)
self._mean = math_utils.nanadd(self._mean, update)

def result(self) -> Self:
return self.mean

def __str__(self):
return f'mean: {self.mean}'


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
@dataclasses.dataclass(kw_only=True, eq=True)
class MeanAndVariance(Mean):
"""Computes the mean and variance of a batch of values."""

_var: types.NumbersT = np.nan

def new(self, batch: types.NumbersT) -> types.NumbersT:
batch = np.asarray(
self.batch_score_fn(batch) if self.batch_score_fn else batch
)
return self.__class__( # pytype: disable=wrong-keyword-args
_count=np.sum(~np.isnan(batch), axis=0),
_mean=np.nanmean(batch, axis=0),
_var=np.nanvar(batch, axis=0),
_input_shape=batch.shape if batch.size else (),
)

@property
def var(self) -> types.NumbersT:
return self._var

@property
def stddev(self) -> types.NumbersT:
return np.sqrt(self._var)

def merge(self, other: Self) -> None:
if np.all(np.isnan(other.var)):
return
prev_mean, prev_count = np.copy(self._mean), np.copy(self._count)
super().merge(other)
if np.all(np.isnan(self._var)):
self._var = other.var
return
# Reference
# (https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups)
prev_count_ratio = math_utils.safe_divide(prev_count, self._count)
other_count_ratio = math_utils.safe_divide(other.count, self._count)
delta_mean = math_utils.nanadd(self._mean, -prev_mean)
mean_diff = math_utils.nanadd(other.mean, -self._mean)
self._var = (
prev_count_ratio * self._var
+ other_count_ratio * other.var
+ prev_count_ratio * delta_mean**2
+ other_count_ratio * mean_diff**2
)

def result(self) -> types.NumbersT:
return self

def __str__(self):
return (
f'count: {self.count}, total: {self.total}, mean: {self.mean}, '
f'var: {self.var}, stddev: {self.stddev}'
)


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
class Var(MeanAndVariance):

def result(self) -> types.NumbersT:
return self.var

def __str__(self):
return f'var: {self.var}'


@telemetry.class_monitor(api='ml_metrics', category=telemetry.CATEGORY.STATS)
# TODO(b/345249574): Add a preprocessing function of len per row.
@dataclasses.dataclass(slots=True)
Expand Down
Loading
Loading