diff --git a/clu/metrics.py b/clu/metrics.py index d8e247d..dd858f5 100644 --- a/clu/metrics.py +++ b/clu/metrics.py @@ -92,6 +92,8 @@ def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray): M = TypeVar("M", bound="Metric") +R = TypeVar("R", jnp.ndarray, dict[str, jnp.ndarray]) +V = TypeVar("V", clu.values.Value, dict[str, clu.values.Value]) class Metric: @@ -160,7 +162,7 @@ def merge(self: M, other: M) -> M: def _reduce_merge(self: M, other: M) -> M: return self.merge(other) - def compute(self) -> jnp.ndarray: + def compute(self) -> R: """Computes final metrics from intermediate values.""" raise NotImplementedError("Must override compute()") @@ -169,9 +171,13 @@ def empty(cls: type[M]) -> M: """Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op).""" raise NotImplementedError("Must override empty()") - def compute_value(self) -> clu.values.Value: - """Wraps compute() and returns a values.Value.""" - return clu.values.Scalar(self.compute()) + def compute_value(self) -> V: + """Wraps compute() and returns a values.Value or dict of values.Value.""" + result = self.compute() + if isinstance(result, dict): + return {k: clu.values.Scalar(v) for k, v in result.items()} + else: + return clu.values.Scalar(result) def reduce(self: M) -> M: """Reduces the metric along it first axis by calling `_reduce_merge()`. 
@@ -623,22 +629,34 @@ def reduce(self: C) -> C: }) def compute(self) -> dict[str, jnp.ndarray]: - """Returns a dictionary mapping metric field name to `Metric.compute()`.""" + """Returns a dictionary mapping metrics to their computed values.""" _check_reduction_counter_ndim(self._reduction_counter) - return { - metric_name: metric.compute() - for metric_name, metric in vars(self).items() - if metric_name != "_reduction_counter" - } + metric_results = {} + for metric_name, metric in vars(self).items(): + if metric_name != "_reduction_counter": + metric_result = metric.compute() + if isinstance(metric_result, dict): + metric_results.update( + {f"{metric_name}/{k}": v for k, v in metric_result.items()} + ) + else: + metric_results[metric_name] = metric_result + return metric_results def compute_values(self) -> dict[str, clu.values.Value]: - """Computes metrics and returns them as clu.values.Value.""" + """Computes metrics and returns them as clu.values.Value.""" _check_reduction_counter_ndim(self._reduction_counter) - return { - metric_name: metric.compute_value() - for metric_name, metric in vars(self).items() - if metric_name != "_reduction_counter" - } + metric_results = {} + for metric_name, metric in vars(self).items(): + if metric_name != "_reduction_counter": + metric_result = metric.compute_value() + if isinstance(metric_result, dict): + metric_results.update( + {f"{metric_name}/{k}": v for k, v in metric_result.items()} + ) + else: + metric_results[metric_name] = metric_result + return metric_results def unreplicate(self: C) -> C: """Short-hand for `flax.jax_utils.unreplicate(self)`.