Commit 0572bfa
[pmap] Avoid degraded performance under the new
This change prepares for the new `jax.pmap` by implementing the recommended mechanism for accessing the first shard in a sharded array. A common pattern used with `jax.pmap` is to shard an array that is semantically replicated and grabbing the first shard is meant to "unreplicate". However, JAX does not know that a sharded array is actually replicated, so we must now explicitly grab the first shard.
The change is under the `jax_pmap_shmap_merge` configuration flag. If `True`, the new `jax.pmap` implementation based on `jax.jit(jax.shard_map)` is used and requires the new explicit shard access. If `False`, the old `jax.pmap` implementation is used and there is a special case in how `x[0]` works.
Please see details here: https://docs.jax.dev/en/latest/migrate_pmap.html#int-array-indexing-into-sharded-arrays
PiperOrigin-RevId: 846723848jax.pmap.1 parent 52dff62 commit 0572bfa
1 file changed
+38
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
193 | 193 | | |
194 | 194 | | |
195 | 195 | | |
196 | | - | |
197 | | - | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
198 | 234 | | |
199 | 235 | | |
200 | 236 | | |
| |||
0 commit comments