diff --git a/torchtitan/components/quantization/__init__.py b/torchtitan/components/quantization/__init__.py
index de94c37b3e..b8d186c92f 100644
--- a/torchtitan/components/quantization/__init__.py
+++ b/torchtitan/components/quantization/__init__.py
@@ -58,5 +58,5 @@ def _validate(job_config: JobConfig):
 
 # Import to register quantization modules as ModelConverter
 # (imports down here to avoid circular imports with QuantizationConverter)
-import torchtitan.components.quantization.float8  # noqa: F401
-import torchtitan.components.quantization.mx  # noqa: F401
+import torchtitan.components.quantization.float8.converters  # noqa: F401
+import torchtitan.components.quantization.mx.converters  # noqa: F401
diff --git a/torchtitan/components/quantization/float8/__init__.py b/torchtitan/components/quantization/float8/__init__.py
new file mode 100644
index 0000000000..2e41cd717f
--- /dev/null
+++ b/torchtitan/components/quantization/float8/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8/converters.py
similarity index 99%
rename from torchtitan/components/quantization/float8.py
rename to torchtitan/components/quantization/float8/converters.py
index 86932a17bd..c6aad826b5 100644
--- a/torchtitan/components/quantization/float8.py
+++ b/torchtitan/components/quantization/float8/converters.py
@@ -11,6 +11,7 @@
     FP8_GROUP_ALIGNMENT_SIZE,
     QuantizationConverter,
 )
+from torchtitan.components.quantization.utils import module_filter_fn
 from torchtitan.config.job_config import Float8Linear, JobConfig
 from torchtitan.distributed import ParallelDims
 
@@ -19,8 +20,6 @@
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import has_cuda_capability
 
-from .utils import module_filter_fn
-
 AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"
diff --git a/torchtitan/components/quantization/mx/__init__.py b/torchtitan/components/quantization/mx/__init__.py
new file mode 100644
index 0000000000..2e41cd717f
--- /dev/null
+++ b/torchtitan/components/quantization/mx/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
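A note on the `module_filter_fn` import change in the renamed modules: once `float8.py` becomes `float8/converters.py` (and likewise for `mx` below), the old single-dot relative import would resolve against the new `float8`/`mx` subpackage instead of the `quantization` package, so the diff switches to the absolute path. A small illustrative sketch of the import spellings inside the moved module; only the absolute form actually appears in this diff:

```python
# Inside torchtitan/components/quantization/float8/converters.py after the move:

# What the diff uses (absolute import of the shared helper):
from torchtitan.components.quantization.utils import module_filter_fn

# A relative spelling that would still reach quantization/utils.py from the new
# location needs two dots now (shown for illustration only):
# from ..utils import module_filter_fn

# The pre-move spelling would now incorrectly look for a non-existent
# torchtitan.components.quantization.float8.utils module:
# from .utils import module_filter_fn
```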
diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx/converters.py
similarity index 98%
rename from torchtitan/components/quantization/mx.py
rename to torchtitan/components/quantization/mx/converters.py
index a474cc3918..ceffdc4260 100644
--- a/torchtitan/components/quantization/mx.py
+++ b/torchtitan/components/quantization/mx/converters.py
@@ -13,6 +13,7 @@
     MXFP8_GROUP_ALIGNMENT_SIZE,
     QuantizationConverter,
 )
+from torchtitan.components.quantization.utils import module_filter_fn
 from torchtitan.config.job_config import JobConfig
 from torchtitan.distributed import ParallelDims
 
@@ -21,8 +22,6 @@
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import has_cuda_capability
 
-from .utils import module_filter_fn
-
 
 class MXLinearConverter(QuantizationConverter):
     """Converts the linear layers of `model` to `MXLinear`."""
diff --git a/torchtitan/components/quantization/mx/expert_parallel.py b/torchtitan/components/quantization/mx/expert_parallel.py
new file mode 100644
index 0000000000..b9e0d1533d
--- /dev/null
+++ b/torchtitan/components/quantization/mx/expert_parallel.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtitan.distributed.expert_parallel import ExpertParallel
+
+
+class MXExpertParallel(ExpertParallel):
+    def __init__(self) -> None:
+        super().__init__()
+        try:
+            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
+                to_mxfp8_a2a_dequant,
+            )
+        except ImportError as err:
+            raise ImportError(
+                "Please install torchao v0.14+ to use MXExpertParallel"
+            ) from err
+        self._a2a_dispatch_impl = to_mxfp8_a2a_dequant
+        self._a2a_combine_impl = to_mxfp8_a2a_dequant
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 7137579f18..7e544a5929 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -752,6 +752,24 @@ class Quantize:
     grouped_mm: QuantizedGroupedMM = field(default_factory=QuantizedGroupedMM)
     """Quantized training config for grouped GEMMs"""
 
+    expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
+    """
+    All-to-all implementation to use for the token dispatch step in expert parallelism.
+    - "default": Directly uses all_to_all_single with inputs/outputs in original precision.
+    - "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
+      using all_to_all_single on the quantized data and scales, then dequantizing
+      the outputs back to original precision.
+    """
+
+    expert_parallel_a2a_combine_impl: Literal["default", "mxfp8"] = "default"
+    """
+    All-to-all implementation to use for the token combine step in expert parallelism.
+    - "default": Directly uses all_to_all_single with inputs/outputs in original precision.
+    - "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
+      using all_to_all_single on the quantized data and scales, then dequantizing
+      the outputs back to original precision.
+    """
+
 
 @dataclass
 class Comm:
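The two new `Quantize` fields above are plain `Literal["default", "mxfp8"]` options that default to `"default"`. A minimal sketch of setting them programmatically, assuming the remaining `Quantize` fields all carry defaults (as the ones visible in this hunk do); in a real run they would normally be populated from the job config rather than constructed directly:

```python
from torchtitan.config.job_config import Quantize

# Hypothetical direct construction, for illustration only: opt in to the MXFP8
# all-to-all for both the dispatch and combine steps of expert parallelism.
quantize_cfg = Quantize(
    expert_parallel_a2a_dispatch_impl="mxfp8",
    expert_parallel_a2a_combine_impl="mxfp8",
)

# The defaults preserve existing behavior (all_to_all_single in original precision).
assert Quantize().expert_parallel_a2a_dispatch_impl == "default"
assert Quantize().expert_parallel_a2a_combine_impl == "default"
```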
+ """ + @dataclass class Comm: diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index e9986b9974..cf7d9a5283 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -71,6 +71,8 @@ def __init__(self): self.output_splits = None self.input_shape = None self.permuted_indices = None + self._a2a_dispatch_impl = all_to_all_single_autograd + self._a2a_combine_impl = all_to_all_single_autograd # performing all-to-all dispatch on the input def _token_dispatch(self, mod, inputs, device_mesh): @@ -107,7 +109,7 @@ def _token_dispatch(self, mod, inputs, device_mesh): self.output_splits = output_splits.tolist() # perform all-to-all - routed_input = all_to_all_single_autograd( + routed_input = self._a2a_dispatch_impl( routed_input, self.output_splits, self.input_splits, @@ -150,7 +152,7 @@ def _token_combine(self, mod, routed_output, device_mesh): routed_output, self.input_shape, self.permuted_indices ) - routed_output = all_to_all_single_autograd( + routed_output = self._a2a_combine_impl( routed_output, self.input_splits, self.output_splits, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 8d13a3f31f..c9df982676 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -91,6 +91,7 @@ def parallelize_deepseekv3( if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, + job_config, tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, ep_tp_mesh=( diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 1f579ccd04..7b5390dd19 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -17,6 +17,7 @@ RowwiseParallel, SequenceParallel, ) +from torchtitan.components.quantization.mx.expert_parallel import MXExpertParallel from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.config.job_config import Compile as CompileConfig from torchtitan.distributed import NoParallel, ParallelDims @@ -98,6 +99,7 @@ def parallelize_llama( if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, + job_config, tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, ep_tp_mesh=( @@ -436,6 +438,7 @@ def apply_fsdp( def apply_moe_ep_tp( model: nn.Module, + job_config: JobConfig, tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, ep_tp_mesh: DeviceMesh | None, @@ -443,6 +446,26 @@ def apply_moe_ep_tp( ): assert ep_mesh is not None or tp_mesh is not None + EP_IMPLS = { + "default": ExpertParallel, + "mxfp8": MXExpertParallel, + } + assert ( + job_config.quantize.expert_parallel_a2a_dispatch_impl in EP_IMPLS + ), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_dispatch_impl}, must be one of {EP_IMPLS.keys()}" + assert ( + job_config.quantize.expert_parallel_a2a_combine_impl in EP_IMPLS + ), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_combine_impl}, must be one of {EP_IMPLS.keys()}" + + logger.info( + f"Using all-to-all dispatch implementation: {job_config.quantize.expert_parallel_a2a_dispatch_impl}" + ) + logger.info( + f"Using all-to-all combine implementation: {job_config.quantize.expert_parallel_a2a_combine_impl}" + ) + + ep_class = 
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 1f579ccd04..7b5390dd19 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -17,6 +17,7 @@
     RowwiseParallel,
     SequenceParallel,
 )
+from torchtitan.components.quantization.mx.expert_parallel import MXExpertParallel
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.config.job_config import Compile as CompileConfig
 from torchtitan.distributed import NoParallel, ParallelDims
@@ -98,6 +99,7 @@ def parallelize_llama(
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
             model,
+            job_config,
             tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
             ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
             ep_tp_mesh=(
@@ -436,6 +438,7 @@ def apply_fsdp(
 
 def apply_moe_ep_tp(
     model: nn.Module,
+    job_config: JobConfig,
     tp_mesh: DeviceMesh | None,
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
@@ -443,6 +446,26 @@ def apply_moe_ep_tp(
 ):
     assert ep_mesh is not None or tp_mesh is not None
 
+    EP_IMPLS = {
+        "default": ExpertParallel,
+        "mxfp8": MXExpertParallel,
+    }
+    assert (
+        job_config.quantize.expert_parallel_a2a_dispatch_impl in EP_IMPLS
+    ), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_dispatch_impl}, must be one of {EP_IMPLS.keys()}"
+    assert (
+        job_config.quantize.expert_parallel_a2a_combine_impl in EP_IMPLS
+    ), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_combine_impl}, must be one of {EP_IMPLS.keys()}"
+
+    logger.info(
+        f"Using all-to-all dispatch implementation: {job_config.quantize.expert_parallel_a2a_dispatch_impl}"
+    )
+    logger.info(
+        f"Using all-to-all combine implementation: {job_config.quantize.expert_parallel_a2a_combine_impl}"
+    )
+
+    ep_class = EP_IMPLS[job_config.quantize.expert_parallel_a2a_dispatch_impl]
+
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
             continue
@@ -491,7 +514,7 @@ def apply_moe_ep_tp(
         elif tp_mesh is None or not etp_enabled:
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
-            experts_plan = ExpertParallel()
+            experts_plan = ep_class()
         else:
             experts_mesh = ep_tp_mesh
             experts_plan = ExpertTensorParallel()
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 5fa8549e9f..24c3ff3c91 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -95,6 +95,7 @@ def parallelize_qwen3(
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
             model,
+            job_config,
             tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
             ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
             ep_tp_mesh=(
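Selecting `"mxfp8"` for either option routes expert parallelism through `MXExpertParallel`: as written above, `apply_moe_ep_tp` picks the `ExpertParallel` subclass from the dispatch option, and `MXExpertParallel` then installs the MXFP8 kernel for both dispatch and combine. That kernel comes from torchao, and the constructor raises if it is missing. A small, hypothetical pre-flight check mirroring that guard (import path and version requirement taken verbatim from the diff's error message):

```python
# Hypothetical helper, not part of the PR: verify that the torchao kernel the
# "mxfp8" expert-parallel all-to-all depends on is importable before launching.
def mxfp8_a2a_available() -> bool:
    try:
        from torchao.prototype.moe_training.kernels.mxfp8.comms import (  # noqa: F401
            to_mxfp8_a2a_dequant,
        )
    except ImportError:
        return False
    return True


if not mxfp8_a2a_available():
    print("Install torchao v0.14+ to use expert_parallel_a2a_*_impl = 'mxfp8'")
```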