[mxfp8 moe training] add mxfp8 all to all impl #1912
base: main
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.distributed.expert_parallel import ExpertParallel


class MXExpertParallel(ExpertParallel):
    def __init__(self) -> None:
        super().__init__()
        try:
            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
                to_mxfp8_a2a_dequant,
            )
        except ImportError as err:
            raise ImportError(
                "Please install torchao v0.14+ to use MXExpertParallel"
            ) from err
        self._a2a_dispatch_impl = to_mxfp8_a2a_dequant
        self._a2a_combine_impl = to_mxfp8_a2a_dequant
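The torchao kernel wired in here, `to_mxfp8_a2a_dequant`, is described later in this diff as quantizing inputs to MXFP8, running `all_to_all_single` on the quantized data and scales, and dequantizing the outputs. The sketch below illustrates that pattern only: it uses a simplified per-row fp8 scale rather than true MX block scales, and the helper is a hypothetical stand-in, not the torchao API.

```python
# Simplified stand-in for the mxfp8 all-to-all + dequant pattern (NOT the torchao
# kernel): per-row fp8 scaling is used here instead of true MX block scales.
import torch
import torch.distributed as dist


def fp8_all_to_all_dequant_sketch(
    x: torch.Tensor,                       # [num_tokens, dim] routed tokens, e.g. bf16
    output_split_sizes: list[int],
    input_split_sizes: list[int],
    group: dist.ProcessGroup | None = None,
) -> torch.Tensor:
    # 1) Quantize: scale each row into the fp8 e4m3 representable range.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / torch.finfo(torch.float8_e4m3fn).max
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)

    # 2) Exchange the 1-byte payload (viewed as uint8 for comm support) and the
    #    small per-row scales instead of the full-precision activations.
    out_tokens = sum(output_split_sizes)
    x_bytes = x_fp8.view(torch.uint8)
    y_bytes = x_bytes.new_empty(out_tokens, x.shape[-1])
    dist.all_to_all_single(y_bytes, x_bytes, output_split_sizes, input_split_sizes, group=group)

    y_scale = scale.new_empty(out_tokens, 1)
    dist.all_to_all_single(y_scale, scale, output_split_sizes, input_split_sizes, group=group)

    # 3) Dequantize back to the original precision for the local experts.
    return (y_bytes.view(torch.float8_e4m3fn).to(x.dtype) * y_scale).to(x.dtype)
```

Exchanging 1-byte fp8 payloads plus small scales instead of 2-byte bf16 activations is what produces the bandwidth savings described in the config docstrings below.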
@@ -752,6 +752,24 @@ class Quantize:
    grouped_mm: QuantizedGroupedMM = field(default_factory=QuantizedGroupedMM)
    """Quantized training config for grouped GEMMs"""

    expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
Review comment (on `expert_parallel_a2a_dispatch_impl`): I think it's not intuitive to let […] I'd suggest we rename this to […]
    """
    All-to-all implementation to use for the token dispatch step in expert parallelism.
    - "default": Directly uses all_to_all_single with inputs/outputs in original precision.
    - "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
      using all_to_all_single on the quantized data and scales, then dequantizing
      the outputs back to original precision.
    """

    expert_parallel_a2a_combine_impl: Literal["default", "mxfp8"] = "default"
    """
    All-to-all implementation to use for the token combine step in expert parallelism.
    - "default": Directly uses all_to_all_single with inputs/outputs in original precision.
    - "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
      using all_to_all_single on the quantized data and scales, then dequantizing
      the outputs back to original precision.
    """
|
|
@dataclass
class Comm:
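Assuming the `Quantize` dataclass above is importable from `torchtitan.config.job_config` and its other fields have defaults, enabling the new impls might look like the sketch below; in a real run these values would come from the training job config rather than direct construction.

```python
# Hypothetical usage sketch; the field names come from the diff above, but constructing
# Quantize directly is only for illustration -- torchtitan normally populates this
# from the job config file / CLI overrides.
from torchtitan.config.job_config import Quantize

quantize_cfg = Quantize(
    expert_parallel_a2a_dispatch_impl="mxfp8",
    expert_parallel_a2a_combine_impl="mxfp8",
)
print(quantize_cfg.expert_parallel_a2a_dispatch_impl)  # "mxfp8"
```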
@@ -17,6 +17,7 @@
    RowwiseParallel,
    SequenceParallel,
)
from torchtitan.components.quantization.mx.expert_parallel import MXExpertParallel
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.config.job_config import Compile as CompileConfig
from torchtitan.distributed import NoParallel, ParallelDims
@@ -98,6 +99,7 @@ def parallelize_llama(
    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
        apply_moe_ep_tp(
            model,
            job_config,
            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
            ep_tp_mesh=(
@@ -436,13 +438,34 @@ def apply_fsdp(


def apply_moe_ep_tp(
    model: nn.Module,
    job_config: JobConfig,
Review comment (on the new `job_config` parameter): let's only send in […]
    tp_mesh: DeviceMesh | None,
    ep_mesh: DeviceMesh | None,
    ep_tp_mesh: DeviceMesh | None,
    etp_enabled: bool,
):
    assert ep_mesh is not None or tp_mesh is not None

    EP_IMPLS = {
        "default": ExpertParallel,
        "mxfp8": MXExpertParallel,
    }
Review comment (on lines +449 to +452): I think some if-else on quantize_config is enough, no need to have this.
    assert (
        job_config.quantize.expert_parallel_a2a_dispatch_impl in EP_IMPLS
    ), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_dispatch_impl}, must be one of {EP_IMPLS.keys()}"
    assert (
        job_config.quantize.expert_parallel_a2a_combine_impl in EP_IMPLS
    ), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_combine_impl}, must be one of {EP_IMPLS.keys()}"
Review comment (on lines +453 to +458): let's do this in the quantized EP class, if you adopt the […]
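One way to read the two review comments above (a sketch of the suggestion, not code from this PR): drop the `EP_IMPLS` mapping and the asserts, branch directly on the quantize config, and let the quantized EP class validate its own requirements (e.g. the torchao import check already in `MXExpertParallel.__init__`).

```python
# Sketch of the reviewers' suggested simplification (assumed, not from this PR):
# a plain if/else on the quantize config replaces EP_IMPLS + asserts, and
# MXExpertParallel.__init__ keeps owning the torchao availability check.
if job_config.quantize.expert_parallel_a2a_dispatch_impl == "mxfp8":
    ep_class = MXExpertParallel
else:
    ep_class = ExpertParallel
```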
    logger.info(
        f"Using all-to-all dispatch implementation: {job_config.quantize.expert_parallel_a2a_dispatch_impl}"
    )
    logger.info(
        f"Using all-to-all combine implementation: {job_config.quantize.expert_parallel_a2a_combine_impl}"
    )
Review comment (on lines +460 to +465): similar to the comment above, we should print this override info only when the impls are overridden during quantization.
    ep_class = EP_IMPLS[job_config.quantize.expert_parallel_a2a_dispatch_impl]

    for transformer_block in model.layers.values():
        if not transformer_block.moe_enabled:
            continue
@@ -491,7 +514,7 @@ def apply_moe_ep_tp(
        elif tp_mesh is None or not etp_enabled:
            experts_mesh = ep_mesh
            # input / output sharding on the batch / tokens dim
            experts_plan = ep_class()
        else:
            experts_mesh = ep_tp_mesh
            experts_plan = ExpertTensorParallel()
Review comment: wait, you are changing both together, not independently.
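That last comment flags that a single `ep_class`, chosen only from the dispatch setting, flips both all-to-alls at once. Below is a hypothetical way to honor the two knobs independently, building on the fact that `MXExpertParallel` only overrides the `_a2a_dispatch_impl` / `_a2a_combine_impl` attributes; whether the base class supports setting them separately like this is an assumption, not something this PR shows.

```python
# Hypothetical sketch only: select dispatch and combine impls independently by
# overriding the private attributes that MXExpertParallel sets. Assumes the base
# ExpertParallel consults these attributes; not code from this PR.
from torchao.prototype.moe_training.kernels.mxfp8.comms import to_mxfp8_a2a_dequant

experts_plan = ExpertParallel()
if job_config.quantize.expert_parallel_a2a_dispatch_impl == "mxfp8":
    experts_plan._a2a_dispatch_impl = to_mxfp8_a2a_dequant
if job_config.quantize.expert_parallel_a2a_combine_impl == "mxfp8":
    experts_plan._a2a_combine_impl = to_mxfp8_a2a_dequant
```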