Neuron support in Axlearn #566

Open · wants to merge 4 commits into base: main
Conversation

@apoorvtintin (Author) commented:

This PR enables the use of Neuron devices in AXLearn for model training.

  • Chooses the correct mesh for TRN devices for Fuji 7B via the mesh selector flag --mesh_selector=neuron-trn1.32xlarge-64 (see the selection sketch below).
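For concreteness, here is a minimal sketch of how a --mesh_selector value picks a mesh rule. The rule strings come from this PR's diff; `select_mesh` and the plain-dict mesh shapes are illustrative stand-ins rather than AXLearn's actual API, and `TRN_MODEL_AXIS_SIZE` is assumed to be 8.

```python
import re

# Hypothetical matcher: the first rule whose regex fully matches the
# selector string wins. Rule strings are taken from this PR's diff.
MESH_RULES = [
    ("gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)",
     dict(data=-1, fsdp=8)),
    ("neuron-(trn1.32xlarge|trn1n.32xlarge)-(32|64|256|512|1024|2048)",
     dict(data=-1, model=8)),  # assumes TRN_MODEL_AXIS_SIZE == 8
]

def select_mesh(mesh_selector: str) -> dict:
    for pattern, mesh in MESH_RULES:
        if re.fullmatch(pattern, mesh_selector):
            return mesh
    raise ValueError(f"no mesh rule matches {mesh_selector!r}")

print(select_mesh("neuron-trn1.32xlarge-64"))  # {'data': -1, 'model': 8}
```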

@ruomingp (Contributor) left a comment:

Thanks.

axlearn/common/utils.py: outdated review comments (resolved)
```diff
@@ -167,6 +167,10 @@ def get_trainer_kwargs(
             "gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)",
             mesh_shape_from_axes(data=-1, fsdp=8),
         ),
+        (
+            "neuron-(trn1.32xlarge|trn1n.32xlarge)-(32|64|256|512|1024|2048)",
+            mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE),
+        ),
```
Contributor:

How does model=8 compare to fsdp=8? Usually we find fsdp to be more efficient.

Contributor:

Might also be worth listing the step times for different configurations, similar to the other mesh rules.

Reply:

> How does model=8 compare to fsdp=8? Usually we find fsdp to be more efficient.

I am launching an fsdp=8 job with 8 nodes; it is currently blocked on AWS capacity. I hope to have some data to share by Friday.

The previous response from AWS was that FSDP is slower due to higher communication overhead.

Reply:

Tensor parallel (model) is more performant on the trn1 architecture.
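To make the model=8 vs. fsdp=8 discussion concrete, here is a minimal sketch of the two mesh shapes on a 64-core cluster (e.g. two trn1.32xlarge nodes). `resolve_mesh` is a stand-in for the "-1 absorbs the remaining devices" semantics, not AXLearn's `mesh_shape_from_axes` itself.

```python
def resolve_mesh(total_devices: int, **axes: int) -> dict:
    # Resolve -1 to whatever device count remains after the fixed axes.
    fixed = 1
    for size in axes.values():
        if size != -1:
            fixed *= size
    return {name: (total_devices // fixed if size == -1 else size)
            for name, size in axes.items()}

# Mesh proposed in this PR: 8-way tensor parallelism within each node.
print(resolve_mesh(64, data=-1, model=8))  # {'data': 8, 'model': 8}
# The fsdp alternative raised above: same shape, different sharding axis.
print(resolve_mesh(64, data=-1, fsdp=8))   # {'data': 8, 'fsdp': 8}
```

Either way the mesh has 8-way data parallelism; the open question above is whether the second axis is used for FSDP-style weight sharding or tensor (model) parallelism, which is what the requested step-time comparison would settle.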

axlearn/experiments/text/gpt/fuji.py: outdated review comments (resolved)
axlearn/experiments/text/gpt/common.py: outdated review comments (resolved)
```diff
@@ -267,12 +269,17 @@ def model_config(
         batch_axis_names=batch_axis_names,
         seq_axis_names="seq",
     )

+    device_platform = np.asarray(jax.devices())[0].platform
```
Contributor:

jax.devices() during config building may be an unexpected dependency on global state -- should we take a platform arg or similar?

@apoorvtintin (Author) replied on Jul 24, 2024:

We could change it, but I followed the pattern already used here:

`devices = jax.devices()`

Please let me know if the platform flag is necessary, and I can add it. Thanks!
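A minimal sketch of the reviewer's suggestion, assuming a `platform` keyword argument is threaded into config building; the signature and fallback below are illustrative, not the final AXLearn API.

```python
from typing import Optional

import jax

def model_config(*, platform: Optional[str] = None, **kwargs):
    # Prefer an explicit platform so config building stays deterministic;
    # fall back to the current global-state lookup when none is given.
    if platform is None:
        platform = jax.devices()[0].platform
    if platform == "neuron":
        ...  # apply TRN-specific settings here
```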

@kelvin-zou (Contributor) commented:

@apoorvtintin I see this PR has been stale for some time. If there is no objection, I'd like to have @Ruixuan, who is working on Trn on our end, port your change and continue iterating on it.

@ptoulme-aws commented:

> @apoorvtintin I see this PR has been stale for some time. If there is no objection, I'd like to have @Ruixuan, who is working on Trn on our end, port your change and continue iterating on it.

Apoorv is on PTO right now. I am OK with you all taking over this PR. Can you add us as reviewers when you finish? Thanks!

@apoorvtintin (Author) commented:

Thanks for all the reviews. I have addressed most of the comments on the PR.
