[mxfp8 moe training] add mxfp8 all to all impl #1912
base: main
Conversation
def apply_moe_ep_tp(
    model: nn.Module,
    job_config: JobConfig,
Let's only send in job_config.quantize instead of the whole job_config. Also, let's make it a non-positional arg to avoid unnecessary BC breaking.
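A minimal sketch of the suggested signature, with other existing parameters omitted for brevity; the keyword-only quantize_config name is illustrative, not the PR's final API:

import torch.nn as nn

def apply_moe_ep_tp(
    model: nn.Module,
    *,
    quantize_config=None,  # e.g. job_config.quantize; keyword-only to avoid BC breaking
):
    # ... existing EP/TP parallelization logic would go here ...
    ...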
grouped_mm: QuantizedGroupedMM = field(default_factory=QuantizedGroupedMM)
"""Quantized training config for grouped GEMMs"""

expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
I think it's not intuitive to let Quantize dictate what the a2a impl should be. Instead, we should let Quantize override the default setting. I'd suggest we rename this to override_ep_a2a_dispatch. Since we only have mxfp8 for now, you can go with bool, or you could make it Literal | None.
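A sketch of the two config shapes suggested above; only the renamed field is shown, and the surrounding Quantize class body is abbreviated:

from dataclasses import dataclass
from typing import Literal, Optional

@dataclass
class Quantize:
    # Option A: a plain bool, since mxfp8 is the only non-default impl today.
    override_ep_a2a_dispatch: bool = False

    # Option B (alternative): an optional literal, leaving room for future impls.
    # override_ep_a2a_dispatch: Optional[Literal["mxfp8"]] = None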
self._a2a_dispatch_impl = to_mxfp8_a2a_dequant
self._a2a_combine_impl = to_mxfp8_a2a_dequant
Wait, you are changing both together, not independently.
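A sketch of what selecting the two impls independently could look like, with each attribute driven by its own config field (field names follow the diff above; to_mxfp8_a2a_dequant and quantize_config are assumed to already be in scope in this module):

from torch.distributed._functional_collectives import all_to_all_single_autograd

# Inside the EP class __init__, pick dispatch and combine separately.
self._a2a_dispatch_impl = (
    to_mxfp8_a2a_dequant
    if quantize_config.expert_parallel_a2a_dispatch_impl == "mxfp8"
    else all_to_all_single_autograd
)
self._a2a_combine_impl = (
    to_mxfp8_a2a_dequant
    if quantize_config.expert_parallel_a2a_combine_impl == "mxfp8"
    else all_to_all_single_autograd
)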
logger.info(
    f"Using all-to-all dispatch implementation: {job_config.quantize.expert_parallel_a2a_dispatch_impl}"
)
logger.info(
    f"Using all-to-all combine implementation: {job_config.quantize.expert_parallel_a2a_combine_impl}"
)
Similar to the comment above, we should print this override info only when the impls are actually overridden during quantization.
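A sketch of the suggested change, logging only when an impl is overridden away from the default:

if job_config.quantize.expert_parallel_a2a_dispatch_impl != "default":
    logger.info(
        "Overriding all-to-all dispatch implementation: "
        f"{job_config.quantize.expert_parallel_a2a_dispatch_impl}"
    )
if job_config.quantize.expert_parallel_a2a_combine_impl != "default":
    logger.info(
        "Overriding all-to-all combine implementation: "
        f"{job_config.quantize.expert_parallel_a2a_combine_impl}"
    )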
assert (
    job_config.quantize.expert_parallel_a2a_dispatch_impl in EP_IMPLS
), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_dispatch_impl}, must be one of {EP_IMPLS.keys()}"
assert (
    job_config.quantize.expert_parallel_a2a_combine_impl in EP_IMPLS
), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_combine_impl}, must be one of {EP_IMPLS.keys()}"
Let's do this in the quantized EP class, if you adopt the Literal version of the config.
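A sketch of moving that validation into the quantized EP class; the constructor signature here is an assumption for illustration, not the actual MXExpertParallel API:

_SUPPORTED_A2A_IMPLS = ("mxfp8",)

class MXExpertParallel(ExpertParallel):
    def __init__(self, a2a_impl: str = "mxfp8") -> None:
        # Validate here instead of asserting at the call site.
        if a2a_impl not in _SUPPORTED_A2A_IMPLS:
            raise ValueError(
                f"Unknown EP a2a impl: {a2a_impl}, must be one of {_SUPPORTED_A2A_IMPLS}"
            )
        super().__init__()
        self._a2a_impl = a2a_impl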
EP_IMPLS = {
    "default": ExpertParallel,
    "mxfp8": MXExpertParallel,
}
I think some if-else on quantize_config is enough; there's no need to have this EP_IMPLS variable here.
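A sketch of the suggested if-else, replacing the EP_IMPLS dict; the config field name follows the rename proposed earlier in this review and is not final:

if quantize_config is not None and quantize_config.override_ep_a2a_dispatch:
    ep_cls = MXExpertParallel
else:
    ep_cls = ExpertParallel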
Summary

- Adds a config option to select either the "default" or "mxfp8" a2a impl.
- The "mxfp8" impl uses torchao's new to_mxfp8_a2a_dequant, which has the exact same API as the functional collective all_to_all_single_autograd and is differentiable, so it can be used as a drop-in replacement for the default a2a impl.
- to_mxfp8_a2a_dequant works as follows:
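A conceptual sketch of the pattern the name implies (quantize to mxfp8, all-to-all the compact payload, dequantize on the receiving rank); this is not the torchao implementation, and to_mxfp8_quantize / from_mxfp8_dequantize are hypothetical helpers standing in for the real kernels:

from torch.distributed._functional_collectives import all_to_all_single_autograd

def mxfp8_a2a_dequant_sketch(x, output_splits, input_splits, group):
    # 1) Quantize activations to mxfp8 (fp8 element data + e8m0 block scales).
    data_fp8, scales = to_mxfp8_quantize(x)            # hypothetical helper
    # 2) All-to-all the low-precision payloads instead of bf16 tensors.
    data_fp8 = all_to_all_single_autograd(data_fp8, output_splits, input_splits, group)
    scales = all_to_all_single_autograd(scales, output_splits, input_splits, group)
    # 3) Dequantize back to the original dtype on the receiving rank.
    return from_mxfp8_dequantize(data_fp8, scales)     # hypothetical helper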
Performance
Single node benchmarks with 4xB200

- Llama4 16e default configs; FSDP=4, EP=4; AC=none; compile=True; seq_len=8192; local_bs=8
- Reduced num layers from 48 -> 2 to avoid OOM in the single-node setting
Debug model config:
Additional context on design/implementation choices
Additional background on motivation

- 30% of profiled Llama4 model runtime is all-to-all comms
- 47% of avg runtime is devoted to MoE comms in profiled OSS models