diff --git a/clu/metrics.py b/clu/metrics.py index a0cf8e7..103d871 100644 --- a/clu/metrics.py +++ b/clu/metrics.py @@ -864,7 +864,7 @@ def compute(self) -> Any: variance = self.sum_of_squares / self.count - mean**2 # Mathematically variance can never be negative but in reality we may run # into such issues due to numeric reasons. - variance = jnp.clip(variance, a_min=0.0) + variance = jnp.clip(variance, min=0.0) return variance**.5