Error with at[index].set(value) when using jit and shard_map in multi-host setting. #24768

Open
lollcat opened this issue Nov 7, 2024 · 2 comments
Labels
bug Something isn't working

Comments

lollcat commented Nov 7, 2024

Description

My program fails silently when using at[index].set(value) in the multi-host setting with shard_map and jit.

Here is a minimal example to reproduce the error:

    import jax

    jax.distributed.initialize()

    import jax.numpy as jnp
    from jax.experimental import multihost_utils
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh
    from jax.sharding import PartitionSpec as P

    use_jit = True
    use_at_set = True

    data_axis_name = "data"
    x = jnp.ones(10)

    mesh = Mesh(jax.devices(), (data_axis_name,))

    x = multihost_utils.host_local_array_to_global_array(
        local_inputs=x,
        global_mesh=mesh,
        pspecs=P(),
    )

    def simple_fn(x):
        mean = jnp.sum(x)
        if use_at_set:
            out = x.at[-1].set(mean)
        else:
            mask = jnp.zeros(x.shape, dtype=bool).at[-1].set(True)
            out = jnp.where(mask, mean, x)
        return out

    shmap_simple_fn = shard_map(
        simple_fn,
        mesh=mesh,
        in_specs=P(),
        out_specs=P(),
    )
    if use_jit:
        shmap_simple_fn = jax.jit(shmap_simple_fn)
    out = shmap_simple_fn(x)

    out = multihost_utils.global_array_to_host_local_array(out, global_mesh=mesh, pspecs=P())
    jax.block_until_ready(out)
    print(out)

If I set use_jit = False or use_at_set = False, the program runs successfully. It also runs successfully with use_jit = True and use_at_set = True in the single-host, multi-device setting.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.35
numpy:  1.26.4
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0]
device info: TPU v4-32, 4 local devices
process_count: 8
platform: uname_result(system='Linux', node='t1v-n-7e0ce5ae-w-0', release='5.13.0-1023-gcp', version='#28~20.04.1-Ubuntu SMP Wed Mar 30 03:51:07 UTC 2022', machine='x86_64')

lollcat added the bug label Nov 7, 2024

sharadmv (Collaborator) commented

What is the failure mode exactly? Do you get an error? Incorrect values?

lollcat (Author) commented Nov 12, 2024

It fails silently - the program reaches jax.block_until_ready(out) and then terminates without raising an error, before reaching the final print(out) line.
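
For anyone trying to pin down the failure mode further, one option is to launch the script from a small wrapper on each host and record how the process ended. The sketch below is only an illustration of that idea, not something from the original report: `run_and_report` is a hypothetical helper, and the child command is a stand-in that exits abruptly rather than the actual repro script.

```python
import subprocess
import sys


def run_and_report(cmd):
    """Run cmd to completion and return (returncode, last lines of stderr).

    Hypothetical helper for illustration only.
    """
    proc = subprocess.run(cmd, capture_output=True, text=True)
    # Keep only the tail of stderr; runtime logs can be long.
    stderr_tail = proc.stderr.splitlines()[-20:]
    return proc.returncode, stderr_tail


# Stand-in child process: exits immediately without a Python traceback.
# In practice this would be the command that launches the repro script.
child = [sys.executable, "-c", "import os; os._exit(3)"]

returncode, stderr_tail = run_and_report(child)
print("exit code:", returncode)
print("stderr tail:", stderr_tail)
```

On POSIX systems a negative return code from `subprocess.run` means the child was killed by a signal, which would help separate a crash in the runtime from a clean but early exit.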

In the single-host setting, or with use_jit = False or use_at_set = False, the program correctly prints [ 1. 1. 1. 1. 1. 1. 1. 1. 1. 10.].
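
As a sanity check on the expected output, the two update formulations from the repro can be compared on a single process, without `shard_map` or `jax.distributed`. This is a minimal sketch under that assumption; the function names `update_with_at_set` and `update_with_where` are made up here, and the snippet does not exercise the multi-host path where the failure occurs.

```python
import jax
import jax.numpy as jnp
import numpy as np


def update_with_at_set(x):
    # Same body as the use_at_set = True branch of the repro.
    total = jnp.sum(x)
    return x.at[-1].set(total)


def update_with_where(x):
    # Same body as the use_at_set = False branch of the repro.
    total = jnp.sum(x)
    mask = jnp.zeros(x.shape, dtype=bool).at[-1].set(True)
    return jnp.where(mask, total, x)


x = jnp.ones(10)

# Both variants are jitted, matching use_jit = True.
out_at_set = np.asarray(jax.jit(update_with_at_set)(x))
out_where = np.asarray(jax.jit(update_with_where)(x))

print(out_at_set)
print(out_where)
```

Both variants should agree here, which is consistent with the report that the problem only shows up when jit, shard_map, and the multi-host setting are combined.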
