Error with at[index].set(value) when using jit and shard_map in a multi-host setting
#24768
Labels: bug
Description
My program fails silently when using `at[index].set(value)` in the multi-host setting with `shard_map` and `jit`. Here is a minimal example to reproduce the error:
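The reporter's example code was not preserved in this copy of the issue. As a rough illustration of the pattern being described (not the original reproduction), the hypothetical sketch below applies `x.at[0].set(...)` inside a `shard_map`-ped function and optionally wraps it in `jax.jit`, using flags named after `use_jit` and `use_at_set`. The function name `per_shard`, the mesh axis name `"x"` and the array shape are assumptions. It only runs on a single host. A multi-host run would additionally need `jax.distributed.initialize()` on every process and a globally sharded input array, and this sketch is not expected to reproduce the reported failure.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public location in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

# Flags named after the ones in the issue (hypothetical reconstruction).
use_jit = True
use_at_set = True


def per_shard(x):
    # x is the local block held by one device (length 2 here).
    if use_at_set:
        # Functional in-place update of the first element of the local block.
        return x.at[0].set(x[0] + 1.0)
    # Equivalent update without .at[].set(), for comparison.
    return jnp.concatenate([x[:1] + 1.0, x[1:]])


# 1-D mesh over all local devices; axis name "x" is an assumption.
mesh = Mesh(np.array(jax.devices()), ("x",))
f = shard_map(per_shard, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
if use_jit:
    f = jax.jit(f)

n = jax.device_count()
x = jnp.zeros((2 * n,), dtype=jnp.float32)  # two elements per device
y = f(x)
print(y)
```

On a single host, every device's block should come back as `[1., 0.]` for any combination of the two flags, so the global result is that pair repeated once per device.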
If I set `use_jit = False` or `use_at_set = False`, the program runs successfully. It also runs successfully with `use_jit = True` and `use_at_set = True` in the single-host, multiple-device setting.

System info (python version, jaxlib version, accelerator, etc.)