diff --git a/clu/metrics.py b/clu/metrics.py index 103d871..d8e247d 100644 --- a/clu/metrics.py +++ b/clu/metrics.py @@ -57,6 +57,7 @@ def evaluate(variables_p, test_ds): """ from __future__ import annotations from collections.abc import Mapping, Sequence +import inspect from typing import Any, TypeVar, Protocol from absl import logging @@ -536,7 +537,8 @@ def empty(cls: type[C]) -> C: _reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)), **{ metric_name: metric.empty() - for metric_name, metric in cls.__annotations__.items() + for metric_name, metric + in inspect.get_annotations(cls, eval_str=True).items() }) @classmethod @@ -546,7 +548,8 @@ def _from_model_output(cls: type[C], **kwargs) -> C: _reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)), **{ metric_name: metric.from_model_output(**kwargs) - for metric_name, metric in cls.__annotations__.items() + for metric_name, metric + in inspect.get_annotations(cls, eval_str=True).items() }) @classmethod