Conversation

@danielvegamyhre (Contributor) commented Sep 30, 2025

Summary

  • Add a differentiable mxfp8 a2a implemented with the all_to_all_single functional collective, which requires the caller to do a d2h sync to get input_splits/output_splits onto the host, since the all_to_all_single API expects host-side split sizes.
  • This is in contrast to the "on device / sync-free" impl using Triton + symmetric memory ([mxfp8 moe training] mxfp8 a2a working e2e in torchtitan llama4 training; improve tests + bench scripts #3088), which should theoretically achieve better performance in a model like the experimental DSV3 in torchtitan: that model natively supports overallocated symmetric-memory buffers for the data exchange of inputs (fwd) and grad_output (bwd), and passes these padded buffers through the rest of the downstream grouped_mm, scatter_add, etc. ops. That approach is experimental because if expert load skew at some point during the training run exceeds what the overallocated symmetric-memory buffers can hold, the run will crash.
  • The non-experimental titan dsv3/llama4 models do NOT use this method, opting instead to do a d2h sync in order to allocate exactly the memory needed for incoming tokens. With that approach, the mxfp8 on-device kernel is actually worse, because it must perform an expensive, synchronizing d2h aten::item call (for the slice op) to extract the actual received tokens from the overallocated buffer, in both fwd and bwd.
  • Therefore, we implement this simpler, synchronizing mxfp8 a2a, which takes exactly the same approach as the non-experimental titan models described above, except that we quantize the inputs before the exchange and dequantize the outputs after it (see the sketch after this list).
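
To make the dtype transitions and split handling concrete, here is a minimal sketch of the idea, not this PR's actual implementation: `to_mxfp8()` and `to_hp()` are hypothetical quantize/dequantize helpers, and the in-place `torch.distributed.all_to_all_single` is used for brevity where the PR uses the functional collective.

```python
# Sketch only. to_mxfp8()/to_hp() are hypothetical helpers; a real impl may
# also need to view the fp8/e8m0 tensors as uint8 for the collective.
import torch
import torch.distributed as dist


class ToMXFP8AllToAllVDequant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_hp, output_splits, input_splits, group):
        # output_splits / input_splits must be host-side python lists; getting
        # them from device tensors is the d2h sync described above
        # (e.g. splits_gpu.tolist()).
        ctx.group = group
        ctx.fwd_output_splits = output_splits
        ctx.fwd_input_splits = input_splits
        ctx.hp_dtype = x_hp.dtype

        fp8_data, scales = to_mxfp8(x_hp)  # hypothetical quantize helper

        fp8_out = fp8_data.new_empty((sum(output_splits), fp8_data.shape[-1]))
        scales_out = scales.new_empty((sum(output_splits), scales.shape[-1]))

        # Two exchanges: one for the fp8 payload, one for the e8m0 scales.
        dist.all_to_all_single(fp8_out, fp8_data, output_splits, input_splits, group=group)
        dist.all_to_all_single(scales_out, scales, output_splits, input_splits, group=group)

        return to_hp(fp8_out, scales_out, x_hp.dtype)  # hypothetical dequantize helper

    @staticmethod
    def backward(ctx, grad_out_hp):
        # Mirror of forward: quantize grads, exchange with the splits swapped,
        # dequantize back to the high precision dtype.
        fp8_grad, scales = to_mxfp8(grad_out_hp)
        fp8_in = fp8_grad.new_empty((sum(ctx.fwd_input_splits), fp8_grad.shape[-1]))
        scales_in = scales.new_empty((sum(ctx.fwd_input_splits), scales.shape[-1]))
        dist.all_to_all_single(fp8_in, fp8_grad, ctx.fwd_input_splits, ctx.fwd_output_splits, group=ctx.group)
        dist.all_to_all_single(scales_in, scales, ctx.fwd_input_splits, ctx.fwd_output_splits, group=ctx.group)
        return to_hp(fp8_in, scales_in, ctx.hp_dtype), None, None, None
```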

Test plan

  • Added test case: pytest test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py -k MXFP8AllToAllVSyncTest

Benchmarks

Baseline: 4916us
Mxfp8 sync a2a: 4207us (~1.17x speedup)

The end-to-end mxfp8 speedup is smaller than the raw a2a numbers would suggest. Looking at the traces to find out why, we can see the all_to_alls themselves (fp8 data and e8m0 scales) are ~1.51x faster than the bf16 a2a, at 4916us vs (3167 + 81) = 3248us, but the quant/dequant ops eat up much of the time saved.

bf16 trace: (screenshot)

mxfp8 trace: (screenshot)
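
For reference, a minimal sketch of how this kind of kernel-level breakdown can be reproduced with torch.profiler; `run_a2a()` is a placeholder for whichever a2a variant is being benchmarked, not a function from this PR:

```python
import torch
from torch.profiler import ProfilerActivity, profile

def run_a2a():
    # placeholder for the bf16 or mxfp8 a2a call under test
    ...

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        run_a2a()
    torch.cuda.synchronize()

# Per-kernel CUDA times show the a2a vs quant/dequant split.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
# Chrome trace for side-by-side inspection like the screenshots above.
prof.export_chrome_trace("a2a_trace.json")
```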

@pytorch-bot (bot) commented Sep 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3103

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit c8f75c8 with merge base cbd3adb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla (bot) added the CLA Signed label on Sep 30, 2025
@danielvegamyhre added the mx, moe, and topic: not user facing labels and removed the CLA Signed label on Sep 30, 2025
@meta-cla (bot) added the CLA Signed label on Sep 30, 2025
@danielvegamyhre force-pushed the mx-a2a-2 branch 2 times, most recently from 3936d21 to de47c43 on October 1, 2025 at 00:21
return input_offset_for_remote_rank, output_offset_for_remote_rank, num_rows_to_read


class MXFP8AllToAllVSync(torch.autograd.Function):
Contributor

nit: name it something like ToMXFP8ThenAll2AllThenToBf16 to make it clear what dtype transitions are happening

Contributor Author

Makes sense, updated to use more explicit naming. The high-precision dtype can be bf16 or fp32, so I used ToMXFP8AllToAllVDequant.


tokens_per_ep_rank = 8192
dim = 2048
input_tensor = torch.randn(
Contributor

optional: if you quantize this to mxfp8 and then dequantize, you can then test for bitwise equality at the end
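
A minimal sketch of that suggestion, using hypothetical `to_mxfp8` / `to_hp` helper names rather than the test's actual utilities: rounding the random input through mxfp8 once should make the quantize/dequantize inside the a2a lossless, so the result can be compared to a bf16 reference a2a with zero tolerance.

```python
import torch

tokens_per_ep_rank = 8192
dim = 2048
x = torch.randn(tokens_per_ep_rank, dim, dtype=torch.bfloat16, device="cuda")

# Round-trip through mxfp8 once so every value is exactly representable.
fp8_data, scales = to_mxfp8(x)               # hypothetical quantize helper
x = to_hp(fp8_data, scales, torch.bfloat16)  # hypothetical dequantize helper

# ... run the mxfp8 a2a under test and a plain bf16 reference a2a on x ...
# torch.testing.assert_close(out_mxfp8, out_ref, atol=0, rtol=0)
```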

@danielvegamyhre merged commit f9b5c30 into main on Oct 1, 2025 (18 checks passed).