diff --git a/clu/metrics.py b/clu/metrics.py
index 803f2ae..b792cef 100644
--- a/clu/metrics.py
+++ b/clu/metrics.py
@@ -193,8 +193,57 @@ def reduce_step(reduced: M, metric: M) -> tuple[M, None]:
       # pylint: disable-next=protected-access
       return reduced._reduce_merge(metric), None
 
-    first = jax.tree_util.tree_map(lambda x: x[0], self)
-    remainder = jax.tree_util.tree_map(lambda x: x[1:], self)
+    # Avoid degraded performance under the new jax.pmap. See
+    # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
+    # Only use the sharding path for concrete sharded arrays, not tracers.
+    def _is_concrete_sharded(x) -> bool:
+      """Returns whether `x` is concrete with one (1, ...) shard per row."""
+      if isinstance(x, jax.core.Tracer):
+        return False
+      shards = getattr(x, "addressable_shards", None)
+      if not shards:
+        return False
+      # A replicated array also reports one shard per device. Treating those
+      # copies as rows would merge duplicated data, so require exactly one
+      # addressable shard per index along the leading axis.
+      if x.ndim == 0 or len(shards) != x.shape[0]:
+        return False
+      # Only use sharding path when shards have shape (1, ...) from pmap.
+      return all(s.data.ndim > 0 and s.data.shape[0] == 1 for s in shards)
+
+    leaves = jax.tree_util.tree_leaves(self)
+    # `getattr` keeps this working on JAX releases that lack the flag, and
+    # every leaf is checked because the helpers below run on all of them.
+    use_sharding_path = (
+        getattr(jax.config, "jax_pmap_shmap_merge", False)
+        and bool(leaves)
+        and all(_is_concrete_sharded(leaf) for leaf in leaves)
+    )
+
+    if use_sharding_path:
+      # NOTE(review): assumes `addressable_shards` is ordered by position
+      # along the leading axis -- confirm.
+
+      def get_first(x):
+        """Returns the data of the first shard without its leading axis."""
+        return x.addressable_shards[0].data.squeeze(0)
+
+      def get_remainder(x):
+        """Stacks every shard after the first along a new leading axis."""
+        shards = x.addressable_shards
+        rest = shards[1:]
+        if not rest:
+          # Nothing follows the first shard: return an empty leading axis that
+          # keeps that shard's per-row shape and dtype.
+          head = shards[0].data
+          return jnp.empty((0,) + head.shape[1:], dtype=head.dtype)
+        return jnp.stack([s.data.squeeze(0) for s in rest], axis=0)
+
+      first = jax.tree_util.tree_map(get_first, self)
+      remainder = jax.tree_util.tree_map(get_remainder, self)
+    else:
+      first = jax.tree_util.tree_map(lambda x: x[0], self)
+      remainder = jax.tree_util.tree_map(lambda x: x[1:], self)
     # According to b/160868467#comment4, usage of `jax.lax.scan` does not add a
     # significant computational cost for simple metrics where e.g. `jnp.sum`
     # could be used instead.