Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions clu/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray):


M = TypeVar("M", bound="Metric")
R = TypeVar("R", jnp.ndarray, dict[str, jnp.ndarray])
V = TypeVar("V", clu.values.Value, dict[str, clu.values.Value])


class Metric:
Expand Down Expand Up @@ -160,7 +162,7 @@ def merge(self: M, other: M) -> M:
def _reduce_merge(self: M, other: M) -> M:
return self.merge(other)

def compute(self) -> jnp.ndarray:
def compute(self) -> R:
"""Computes final metrics from intermediate values."""
raise NotImplementedError("Must override compute()")

Expand All @@ -169,9 +171,13 @@ def empty(cls: type[M]) -> M:
"""Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op)."""
raise NotImplementedError("Must override empty()")

def compute_value(self) -> clu.values.Value:
"""Wraps compute() and returns a values.Value."""
return clu.values.Scalar(self.compute())
def compute_value(self) -> V:
"""Wraps compute() and returns a values.Value or dict of values.Value."""
result = self.compute()
if isinstance(result, dict):
return {k: clu.values.Scalar(v) for k, v in result.items()}
else:
return clu.values.Scalar(result)

def reduce(self: M) -> M:
"""Reduces the metric along it first axis by calling `_reduce_merge()`.
Expand Down Expand Up @@ -623,22 +629,34 @@ def reduce(self: C) -> C:
})

def compute(self) -> dict[str, jnp.ndarray]:
"""Returns a dictionary mapping metric field name to `Metric.compute()`."""
"""Returns a dictionary mapping metrics to their computed values."""
_check_reduction_counter_ndim(self._reduction_counter)
return {
metric_name: metric.compute()
for metric_name, metric in vars(self).items()
if metric_name != "_reduction_counter"
}
metric_results = {}
for metric_name, metric in vars(self).items():
if metric_name != "_reduction_counter":
metric_result = metric.compute()
if isinstance(metric_result, dict):
metric_results.update(
{f"{metric_name}/{k}": v for k, v in metric_result.items()}
)
else:
metric_results[metric_name] = metric_result
return metric_results

def compute_values(self) -> dict[str, clu.values.Value]:
"""Computes metrics and returns them as clu.values.Value."""
"""Computes metrics and returns them as clu_values.Value."""
_check_reduction_counter_ndim(self._reduction_counter)
return {
metric_name: metric.compute_value()
for metric_name, metric in vars(self).items()
if metric_name != "_reduction_counter"
}
metric_results = {}
for metric_name, metric in vars(self).items():
if metric_name != "_reduction_counter":
metric_result = metric.compute_value()
if isinstance(metric_result, dict):
metric_results.update(
{f"{metric_name}/{k}": v for k, v in metric_result.items()}
)
else:
metric_results[metric_name] = metric_result
return metric_results

def unreplicate(self: C) -> C:
"""Short-hand for `flax.jax_utils.unreplicate(self)`.
Expand Down