Neuron support in Axlearn #566

Open

wants to merge 4 commits into base: main
Changes from 3 commits
4 changes: 4 additions & 0 deletions axlearn/common/utils.py
@@ -1193,6 +1193,10 @@ def create_device_mesh(
        logging.warning("Falling back to ICI-only mesh on GPU, performance may be reduced.")
        return build_standard_mesh(mesh_shape, devices=devices)

    # Neuron also uses only the standard mesh.
    if device_platform == "neuron":
        return build_standard_mesh(mesh_shape, devices=devices)

    # We only break the first device axis (the least communication intensive) across granules.
    assert (
        ici_mesh_shape[0] % num_granules == 0
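For readers unfamiliar with the helper: the standard mesh simply arranges the flat device list into the requested logical shape, with no granule-aware axis splitting. Below is a minimal sketch of that behavior using JAX's mesh_utils; build_standard_mesh_sketch is an illustrative name, and AXLearn's actual build_standard_mesh may differ in detail.

import jax
from jax.experimental import mesh_utils

def build_standard_mesh_sketch(mesh_shape, *, devices=None):
    # Arrange the flat device list into the requested logical mesh shape.
    # This mirrors the fallback path taken for Neuron above: no attempt is
    # made to split an axis across granules (slices/nodes).
    if devices is None:
        devices = jax.devices()
    return mesh_utils.create_device_mesh(mesh_shape, devices=devices)

On a single trn1.32xlarge host this would, for example, arrange 32 NeuronCores into a (4, 8) mesh.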
9 changes: 8 additions & 1 deletion axlearn/experiments/text/gpt/common.py
@@ -11,8 +11,10 @@
"""

import math
import numpy as np
from typing import Dict, List, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.sharding import PartitionSpec
@@ -267,12 +269,17 @@ def model_config(
        batch_axis_names=batch_axis_names,
        seq_axis_names="seq",
    )

    device_platform = np.asarray(jax.devices())[0].platform
Contributor:
jax.devices() during config building may be an unexpected dependency on global state -- should we take a platform arg or similar?

@apoorvtintin (Author), Jul 24, 2024:
We could change it, but I followed the pattern already used here:

devices = jax.devices()

Please let me know if the platform flag is necessary, and I can add it. Thanks!
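As an illustration of the reviewer's suggestion (not code from this PR), a platform argument could default to the current jax.devices() behavior while letting callers pass the platform explicitly; infer_device_platform is a hypothetical helper name:

from typing import Optional

import jax

def infer_device_platform(platform: Optional[str] = None) -> str:
    # Prefer an explicitly supplied platform (e.g. from a config or flag) so
    # that config building does not depend on global device state; otherwise
    # fall back to inspecting the first available device, as the PR does today.
    if platform is not None:
        return platform
    return jax.devices()[0].platform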

    # Neuron uses ZeRO-3 style sharding: weights are also sharded over the data axis.
    fsdp_axis_names = (
        ("expert", "fsdp", "seq")
        if device_platform != "neuron"
        else ("data", "expert", "fsdp", "seq")
    )

    cfg.dtype = jnp.float32
    # Shard some FFN and attention weights over multiple axes.
    set_double_shard_weights_config(
        cfg.decoder.transformer.layer,
        batch_axis_names=batch_axis_names,
-        fsdp_axis_names=("expert", "fsdp", "seq"),
+        fsdp_axis_names=fsdp_axis_names,
        tp_axis_names="model",
        seq_axis_names=("seq",),
    )
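To make the effect of the fsdp_axis_names switch concrete: including "data" means a weight is partitioned across the data-parallel axis as well, rather than replicated over it. The following is a standalone sketch, assuming 8 local devices (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8) and illustrative axis names and shapes rather than AXLearn's actual mesh:

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=("data", "fsdp"))
w = np.zeros((1024, 4096), dtype=np.float32)

# Default behavior: shard dim 0 over "fsdp" only; replicate across "data".
w_fsdp = jax.device_put(w, NamedSharding(mesh, P("fsdp", None)))

# Neuron (ZeRO-3 style): shard dim 0 over "data" and "fsdp" together.
w_zero3 = jax.device_put(w, NamedSharding(mesh, P(("data", "fsdp"), None)))

print(w_fsdp.sharding)   # each device holds a 512 x 4096 shard, replicated 4x over data
print(w_zero3.sharding)  # each device holds a 128 x 4096 shard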
6 changes: 5 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
@@ -83,6 +83,7 @@ class Version(enum.Enum):
    },
}

TRN_MODEL_AXIS_SIZE = 8

def get_trainer_kwargs(
    model_size: str,
@@ -103,7 +104,6 @@ def get_trainer_kwargs(
        num_kv_heads = 8

    rope_theta = ROPE_THETA[version]

    # dict() is more readable here.
    # pylint: disable=use-dict-literal
    if model_size == "test":
@@ -167,6 +167,10 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)",
mesh_shape_from_axes(data=-1, fsdp=8),
),
(
"neuron-(trn1.32xlarge|trn1n.32xlarge)-(32|64|256|512|1024|2048)",
mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE),
Contributor:
How does model=8 compare to fsdp=8? Usually we find fsdp to be more efficient.

Contributor:
Might also be worth listing the step times for different configurations, similar to the other mesh rules.

Reply:

> How does model=8 compare to fsdp=8? Usually we find fsdp to be more efficient.

I am launching an fsdp=8 job with 8 nodes. The job is currently blocked on AWS capacity; I hope to have data to share by Friday.

The previous response from AWS was that FSDP is slower due to higher communication overhead.

Reply:

Tensor parallelism (model) is more performant on the trn1 architecture.

                ),
            ),
        )
    elif model_size == "70B":
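For intuition about the neuron mesh rule above: model is pinned to TRN_MODEL_AXIS_SIZE and data=-1 absorbs the remaining devices at runtime. Here is a small, self-contained sketch of that arithmetic (not the AXLearn implementation); CORES_PER_TRN1_NODE and resolved_neuron_mesh are illustrative names, and the 32 NeuronCores per trn1.32xlarge node is an assumption stated in the comment below:

TRN_MODEL_AXIS_SIZE = 8
CORES_PER_TRN1_NODE = 32  # assumption: 16 Trainium chips x 2 NeuronCores each

def resolved_neuron_mesh(num_nodes: int) -> dict:
    # data=-1 in mesh_shape_from_axes means "use whatever devices remain
    # after the explicitly sized axes".
    total_devices = num_nodes * CORES_PER_TRN1_NODE
    return {"data": total_devices // TRN_MODEL_AXIS_SIZE, "model": TRN_MODEL_AXIS_SIZE}

print(resolved_neuron_mesh(1))   # {'data': 4, 'model': 8}
print(resolved_neuron_mesh(16))  # {'data': 64, 'model': 8}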