
Conversation

danielvegamyhre
Contributor

@danielvegamyhre danielvegamyhre commented Sep 27, 2025

Summary

  • With these changes, the mxfp8 a2a is working e2e in torchtitan Llama4 training (using this PR in torchtitan: [mxfp8 MoE training] Support mxfp8 all to all in expert parallel torchtitan#1765)
  • Perf is currently worse than the bf16 baseline due to a d2h sync caused by an aten::item call when extracting the actual tokens from the overallocated symmetric memory grad_output buffer. This sym mem buffer must be overallocated to account for the fact that the corresponding output from the a2a fwd will be variable size.
  • Therefore, my thinking is:
    1. Use this impl in the experimental DSV3 model: This impl is more suitable for the experimental DSV3 no-sync model, which natively supports this preallocation method by passing the full padded output/grad_input to downstream ops like grouped_mm, scatter_add, etc. unmodified. With a couple of small changes to this mxfp8 impl (e.g., just returning the full padded output and grad_input) we can use it there. This method is experimental because, while it avoids the d2h sync, if there is enough skew in expert routing the job will crash due to insufficient sym mem buffer space to write to during the token exchange.
    2. Add new impl for non-experimental DSV3/Llama4 models: we can add a simpler mxfp8 a2a impl that just kicks off 2 async all_to_all_single_autograds on the e4m3 data and e8m0 scales.
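To make the d2h sync mentioned above concrete, here is a minimal CPU sketch (function name and shapes are hypothetical, and the real buffers are CUDA symmetric memory tensors; the mechanism is the same):

```python
import torch

def extract_actual_tokens(padded_out: torch.Tensor, num_tokens: torch.Tensor) -> torch.Tensor:
    """Slice the real tokens out of an overallocated a2a output buffer.

    When num_tokens lives on the GPU, calling .item() to get a Python int
    forces a device-to-host copy (the aten::item d2h sync visible in the
    profile), because the slice length must be known on the host before
    the next kernel can be launched.
    """
    n = num_tokens.item()  # d2h sync point when num_tokens is on CUDA
    return padded_out[:n]

# toy CPU example: 8-row overallocated buffer, only 5 rows are real tokens
buf = torch.arange(8, dtype=torch.float32).reshape(8, 1)
n = torch.tensor(5)
out = extract_actual_tokens(buf, n)
```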

Changes

  • When integrating the mxfp8 a2a kernel with torchtitan I hit some CUDA illegal memory access (IMA) errors, which I've fixed with more robust bounds checking.
  • Disable compile for forward()/backward() at method level, due to compile not playing nicely with class variables
  • Compile to_mx and to_dtype in fwd and bwd
  • Fix name of unit test
  • Update bench script to measure the real default_a2a and mxfp8_a2a impls used in torchtitan (being added in [mxfp8 MoE training] Support mxfp8 all to all in expert parallel torchtitan#1765)
  • Add option to profile run in bench script

Benchmarks

input_shape         num_splits    bf16_ms    mxfp8_ms
----------------  ------------  ---------  ----------
(16, 8192, 5120)             8     10.684     62.2852

Limitations

  • Extracting actual tokens from overallocated device buffer at end of forward() and backward() causes d2h syncs, hurting perf. Need to think about ways to avoid this.
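One possible direction is the padded-output approach from the PR description: return the full overallocated buffer plus a device-side token count, and let downstream ops mask out padding on device, so no .item() call is needed. A minimal CPU sketch, with a hypothetical function name and masking scheme:

```python
import torch

def a2a_padded_output(padded_buf: torch.Tensor, num_tokens: torch.Tensor):
    """Return the full overallocated buffer plus a device-side count,
    instead of slicing with num_tokens.item() (which forces a d2h sync).

    Downstream ops must then consume the padded tensor unmodified; here
    the padding rows are simply zeroed with an on-device mask.
    """
    mask = torch.arange(padded_buf.shape[0], device=padded_buf.device) < num_tokens
    return padded_buf * mask.unsqueeze(-1).to(padded_buf.dtype), num_tokens

# 8-row overallocated buffer, 5 real token rows
buf = torch.ones(8, 4)
out, n = a2a_padded_output(buf, torch.tensor(5))
```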


pytorch-bot bot commented Sep 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3088

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit dcc9237 with merge base 0d3217d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 27, 2025
@danielvegamyhre danielvegamyhre force-pushed the improve branch 3 times, most recently from 2fb7cb0 to 0250ba0 on September 27, 2025 16:53
@danielvegamyhre danielvegamyhre added the topic: not user facing Use this tag if you don't want this PR to show up in release notes label Sep 27, 2025
@danielvegamyhre
Contributor Author

danielvegamyhre commented Sep 30, 2025

@kwen2501 @vkuzo I need to try a different approach to get better perf (see PR description), but would like to land these incremental changes, which contain an mxfp8 a2a impl that is now at least e2e functional in torchtitan training.

@danielvegamyhre danielvegamyhre changed the title [mxfp8 moe training] fix CUDA IMA and improve bench + test scripts [mxfp8 moe training] mxfp8 a2a working e2e in torchtitan llama4 training; improve tests + bench scripts Sep 30, 2025
@vkuzo
Contributor

vkuzo commented Sep 30, 2025

we can add a simpler mxfp8 a2a impl that just kicks off 2 async all_to_all_single_autograds on the e4m3 data and e8m0 scales

for a general utility I'd start with this and then iterate, sounds simpler

Contributor

@vkuzo vkuzo left a comment


stamp

@danielvegamyhre danielvegamyhre merged commit cbd3adb into main Sep 30, 2025
18 checks passed

Labels

CLA Signed · moe · mx · topic: not user facing
