Skip to content

[pmap] Avoid degraded performance under the new jax.pmap.#375

Merged
copybara-service[bot] merged 1 commit intomainfrom
test_846723848
Jan 7, 2026
Merged

[pmap] Avoid degraded performance under the new jax.pmap.#375
copybara-service[bot] merged 1 commit intomainfrom
test_846723848

Conversation

@copybara-service
Copy link

[pmap] Avoid degraded performance under the new jax.pmap.

This change prepares for the new jax.pmap by implementing the recommended mechanism for accessing the first shard in a sharded array. A common pattern used with jax.pmap is to shard an array that is semantically replicated and grabbing the first shard is meant to "unreplicate". However, JAX does not know that a sharded array is actually replicated, so we must now explicitly grab the first shard.

The change is under the jax_pmap_shmap_merge configuration flag. If True, the new jax.pmap implementation based on jax.jit(jax.shard_map) is used and requires the new explicit shard access. If False, the old jax.pmap implementation is used and there is a special case in how x[0] works.

Please see details here: https://docs.jax.dev/en/latest/migrate_pmap.html#int-array-indexing-into-sharded-arrays

@copybara-service copybara-service bot force-pushed the test_846723848 branch 2 times, most recently from 0572bfa to f23d5c6 Compare January 7, 2026 15:02
PiperOrigin-RevId: 853251958
@copybara-service copybara-service bot merged commit 83ac17a into main Jan 7, 2026
@copybara-service copybara-service bot deleted the test_846723848 branch January 7, 2026 15:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant

Comments