Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions clu/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,44 @@ def reduce_step(reduced: M, metric: M) -> tuple[M, None]:
# pylint: disable-next=protected-access
return reduced._reduce_merge(metric), None

first = jax.tree_util.tree_map(lambda x: x[0], self)
remainder = jax.tree_util.tree_map(lambda x: x[1:], self)
# Avoid degraded performance under the new jax.pmap. See
# https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
# Only use the sharding path for concrete sharded arrays, not tracers.
def _is_concrete_sharded(x):
if isinstance(x, jax.core.Tracer):
return False
if not hasattr(x, "addressable_shards"):
return False
shards = x.addressable_shards
if not shards:
return False
# Only use sharding path when shards have shape (1, ...) from pmap
return shards[0].data.ndim > 0 and shards[0].data.shape[0] == 1

leaves = jax.tree_util.tree_leaves(self)
use_sharding_path = (
jax.config.jax_pmap_shmap_merge
and leaves
and _is_concrete_sharded(leaves[0])
)

if use_sharding_path:

def get_first(x):
  """Returns the first addressable shard's data with axis 0 squeezed away."""
  leading_shard = x.addressable_shards[0]
  return leading_shard.data.squeeze(0)

def get_remainder(x):
  """Returns the data of all addressable shards after the first, stacked.

  Each shard's data has axis 0 squeezed away before stacking along a new
  leading axis. With at most one shard there is nothing left over, so an
  empty array with leading dimension 0 and the squeezed shape and dtype of
  the first shard is returned instead.
  """
  local_shards = x.addressable_shards
  if len(local_shards) > 1:
    tail_rows = [shard.data.squeeze(0) for shard in local_shards[1:]]
    return jnp.stack(tail_rows, axis=0)
  only = local_shards[0].data
  return jnp.empty((0,) + only.squeeze(0).shape, dtype=only.dtype)

first = jax.tree_util.tree_map(get_first, self)
remainder = jax.tree_util.tree_map(get_remainder, self)
else:
first = jax.tree_util.tree_map(lambda x: x[0], self)
remainder = jax.tree_util.tree_map(lambda x: x[1:], self)
# According to b/160868467#comment4, usage of `jax.lax.scan` does not add a
# significant computational cost for simple metrics where e.g. `jnp.sum`
# could be used instead.
Expand Down
Loading