-
Notifications
You must be signed in to change notification settings - Fork 349
[mxfp8 moe training] mxfp8 a2a with d2h sync #3103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3103
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 1 PendingAs of commit c8f75c8 with merge base cbd3adb ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
3936d21
to
de47c43
Compare
return input_offset_for_remote_rank, output_offset_for_remote_rank, num_rows_to_read | ||
|
||
|
||
class MXFP8AllToAllVSync(torch.autograd.Function): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: name it something like ToMXFP8ThenAll2AllThenToBf16
to make it clear what dtype transitions are happening
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, updated to use more explicit naming. The high precision dtype can be bf16 or fp32 so I used ToMXFP8AllToAllVDequant
|
||
tokens_per_ep_rank = 8192 | ||
dim = 2048 | ||
input_tensor = torch.randn( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optional: if you quantize this to mxfp8 and then dequantize, you can then test for bitwise equality at the end
de47c43
to
ebf0e32
Compare
ebf0e32
to
c8f75c8
Compare
Summary
inputs
(fwd) andgrad_output
(bwd), and passes these padded buffers through the rest of the downstream grouped_mm, scatter_add etc ops. This approach is experimental because if there is sufficient expert load skew at some point during the training run and the overallocated sym mem buffs are not large enough, the run will crash.Test plan
pytest test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py -k MXFP8AllToAllVSyncTest
Benchmarks
Baseline: 4916us
Mxfp8 sync a2a: 4207us (~1.17x speedup)
mxfp8 perf is basically the same as bf16 baseline. Looking at traces to find out why, we can see the all_to_alls themselves (fp8 data and e8m0 scales) are 1.51x faster than the bf16 a2a at 4916us vs (3167+81)=3248us, but the quant/dequant ops are consuming the entire
bf16 trace:

mxfp8 trace:
