4 changes: 2 additions & 2 deletions torchtitan/components/quantization/__init__.py
@@ -58,5 +58,5 @@ def _validate(job_config: JobConfig):

# Import to register quantization modules as ModelConverter
# (imports down here to avoid circular imports with QuantizationConverter)
import torchtitan.components.quantization.float8 # noqa: F401
import torchtitan.components.quantization.mx # noqa: F401
import torchtitan.components.quantization.float8.converters # noqa: F401
import torchtitan.components.quantization.mx.converters # noqa: F401
5 changes: 5 additions & 0 deletions torchtitan/components/quantization/float8/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -11,6 +11,7 @@
FP8_GROUP_ALIGNMENT_SIZE,
QuantizationConverter,
)
from torchtitan.components.quantization.utils import module_filter_fn

from torchtitan.config.job_config import Float8Linear, JobConfig
from torchtitan.distributed import ParallelDims
@@ -19,8 +20,6 @@
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import has_cuda_capability

from .utils import module_filter_fn

AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"


5 changes: 5 additions & 0 deletions torchtitan/components/quantization/mx/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -13,6 +13,7 @@
MXFP8_GROUP_ALIGNMENT_SIZE,
QuantizationConverter,
)
from torchtitan.components.quantization.utils import module_filter_fn

from torchtitan.config.job_config import JobConfig
from torchtitan.distributed import ParallelDims
@@ -21,8 +22,6 @@
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import has_cuda_capability

from .utils import module_filter_fn


class MXLinearConverter(QuantizationConverter):
"""Converts the linear layers of `model` to `MXLinear`."""
22 changes: 22 additions & 0 deletions torchtitan/components/quantization/mx/expert_parallel.py
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.distributed.expert_parallel import ExpertParallel


class MXExpertParallel(ExpertParallel):
def __init__(self) -> None:
super().__init__()
try:
from torchao.prototype.moe_training.kernels.mxfp8.comms import (
to_mxfp8_a2a_dequant,
)
except ImportError as err:
raise ImportError(
"Please install torchao v0.14+ to use MXExpertParallel"
) from err
self._a2a_dispatch_impl = to_mxfp8_a2a_dequant
self._a2a_combine_impl = to_mxfp8_a2a_dequant
Comment on lines +21 to +22
Contributor: Wait, you are changing both together, not independently.
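
The new class swaps both all-to-all hooks for torchao's to_mxfp8_a2a_dequant while inheriting all token-permutation and split bookkeeping from ExpertParallel. A minimal usage sketch, assuming torchao v0.14+ is installed; make_experts_plan is an illustrative helper, not part of this PR:

```python
from torchtitan.components.quantization.mx.expert_parallel import MXExpertParallel
from torchtitan.distributed.expert_parallel import ExpertParallel


def make_experts_plan(a2a_impl: str) -> ExpertParallel:
    # "mxfp8" quantizes a2a payloads to MXFP8 and dequantizes on receipt;
    # "default" keeps all_to_all_single in the original precision.
    return MXExpertParallel() if a2a_impl == "mxfp8" else ExpertParallel()


experts_plan = make_experts_plan("mxfp8")
```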

18 changes: 18 additions & 0 deletions torchtitan/config/job_config.py
@@ -752,6 +752,24 @@ class Quantize:
grouped_mm: QuantizedGroupedMM = field(default_factory=QuantizedGroupedMM)
"""Quantized training config for grouped GEMMs"""

expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
Contributor: I think it's not intuitive to let Quantize dictate what the a2a impl should be. Instead, we should let Quantize override the default setting.
I'd suggest renaming this to override_ep_a2a_dispatch. Since we only have mxfp8 for now, you can go with a bool, or you could make it Literal | None.

"""
All-to-all implementation to use for the token dispatch step in expert parallelism.
- "default": Directly uses all_to_all_single with inputs/outputs in original precision.
- "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
using all_to_all_single on the quantized data and scales, then dequantizing
the outputs back to original precision.
"""

expert_parallel_a2a_combine_impl: Literal["default", "mxfp8"] = "default"
"""
All-to-all implementation to use for the token combine step in expert parallelism.
- "default": Directly uses all_to_all_single with inputs/outputs in original precision.
- "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
using all_to_all_single on the quantized data and scales, then dequantizing
the outputs back to original precision.
"""


@dataclass
class Comm:
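
The two fields added above are read later in apply_moe_ep_tp as job_config.quantize.expert_parallel_a2a_dispatch_impl and expert_parallel_a2a_combine_impl. The review comment above suggests an override-style knob instead; a hedged sketch of that shape follows. The name override_ep_a2a_dispatch comes from the comment, while the combine counterpart, the typing, and the defaults are assumptions:

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class Quantize:
    # Sketch of the reviewer's suggested shape (not what this PR implements):
    # Quantize only overrides the EP a2a impls when explicitly asked to.
    override_ep_a2a_dispatch: Literal["mxfp8"] | None = None
    """If set, overrides the default expert-parallel a2a dispatch impl."""

    override_ep_a2a_combine: Literal["mxfp8"] | None = None
    """If set, overrides the default expert-parallel a2a combine impl."""
```

With the fields as actually added in this hunk, enabling MXFP8 for both steps means setting both expert_parallel_a2a_dispatch_impl and expert_parallel_a2a_combine_impl to "mxfp8" in the job config.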
6 changes: 4 additions & 2 deletions torchtitan/distributed/expert_parallel.py
@@ -71,6 +71,8 @@ def __init__(self):
self.output_splits = None
self.input_shape = None
self.permuted_indices = None
self._a2a_dispatch_impl = all_to_all_single_autograd
self._a2a_combine_impl = all_to_all_single_autograd

# performing all-to-all dispatch on the input
def _token_dispatch(self, mod, inputs, device_mesh):
@@ -107,7 +109,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
self.output_splits = output_splits.tolist()

# perform all-to-all
routed_input = all_to_all_single_autograd(
routed_input = self._a2a_dispatch_impl(
routed_input,
self.output_splits,
self.input_splits,
@@ -150,7 +152,7 @@ def _token_combine(self, mod, routed_output, device_mesh):
routed_output, self.input_shape, self.permuted_indices
)

routed_output = all_to_all_single_autograd(
routed_output = self._a2a_combine_impl(
routed_output,
self.input_splits,
self.output_splits,
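Any drop-in impl assigned to these hooks is called with the same arguments that all_to_all_single_autograd received before: the routed tensor, the output splits, the input splits, and whatever trailing arguments the hunk elides. A minimal sketch of swapping one hook independently of the other; the subclass name and the print are illustrative only:

```python
from torchtitan.distributed.expert_parallel import ExpertParallel


class LoggedDispatchExpertParallel(ExpertParallel):
    def __init__(self) -> None:
        super().__init__()
        default_impl = self._a2a_dispatch_impl  # all_to_all_single_autograd by default

        def logged_dispatch(routed_input, output_splits, input_splits, *args, **kwargs):
            # Count the rows being dispatched before delegating to the default impl.
            print(f"dispatching {routed_input.shape[0]} tokens")
            return default_impl(routed_input, output_splits, input_splits, *args, **kwargs)

        self._a2a_dispatch_impl = logged_dispatch  # combine hook keeps the default
```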
1 change: 1 addition & 0 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -91,6 +91,7 @@ def parallelize_deepseekv3(
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
apply_moe_ep_tp(
model,
job_config,
tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
ep_tp_mesh=(
25 changes: 24 additions & 1 deletion torchtitan/models/llama4/infra/parallelize.py
@@ -17,6 +17,7 @@
RowwiseParallel,
SequenceParallel,
)
from torchtitan.components.quantization.mx.expert_parallel import MXExpertParallel
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.config.job_config import Compile as CompileConfig
from torchtitan.distributed import NoParallel, ParallelDims
@@ -98,6 +99,7 @@ def parallelize_llama(
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
apply_moe_ep_tp(
model,
job_config,
tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
ep_tp_mesh=(
@@ -436,13 +438,34 @@

def apply_moe_ep_tp(
model: nn.Module,
job_config: JobConfig,
Contributor: Let's only send in job_config.quantize instead of the whole job_config. Also, let's make it a non-positional arg to avoid unnecessary BC breakage.

tp_mesh: DeviceMesh | None,
ep_mesh: DeviceMesh | None,
ep_tp_mesh: DeviceMesh | None,
etp_enabled: bool,
):
assert ep_mesh is not None or tp_mesh is not None

EP_IMPLS = {
"default": ExpertParallel,
"mxfp8": MXExpertParallel,
}
Comment on lines +449 to +452
Contributor: I think some if/else on quantize_config is enough; no need to have this EP_IMPLS variable here.

assert (
job_config.quantize.expert_parallel_a2a_dispatch_impl in EP_IMPLS
), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_dispatch_impl}, must be one of {EP_IMPLS.keys()}"
assert (
job_config.quantize.expert_parallel_a2a_combine_impl in EP_IMPLS
), f"Unknown EP impl: {job_config.quantize.expert_parallel_a2a_combine_impl}, must be one of {EP_IMPLS.keys()}"
Comment on lines +453 to +458
Contributor: Let's do this in the quantized EP class, if you adopt the Literal version of the config.


logger.info(
f"Using all-to-all dispatch implementation: {job_config.quantize.expert_parallel_a2a_dispatch_impl}"
)
logger.info(
f"Using all-to-all combine implementation: {job_config.quantize.expert_parallel_a2a_combine_impl}"
)
Comment on lines +460 to +465
Contributor: Similar to the comment above, we should print this override info only when the defaults are actually overridden for quantization.


ep_class = EP_IMPLS[job_config.quantize.expert_parallel_a2a_dispatch_impl]

for transformer_block in model.layers.values():
if not transformer_block.moe_enabled:
continue
@@ -491,7 +514,7 @@ def apply_moe_ep_tp(
elif tp_mesh is None or not etp_enabled:
experts_mesh = ep_mesh
# input / output sharding on the batch / tokens dim
experts_plan = ExpertParallel()
experts_plan = ep_class()
else:
experts_mesh = ep_tp_mesh
experts_plan = ExpertTensorParallel()
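
Taken together, the review comments on this hunk point toward a different shape for the selection logic: pass only the Quantize sub-config (keyword-only), use a plain if/else instead of the EP_IMPLS dict, let the quantized EP class validate its own requirements, and log only when the default is actually overridden. A hedged sketch under those assumptions, not the code in this PR:

```python
from torchtitan.components.quantization.mx.expert_parallel import MXExpertParallel
from torchtitan.config.job_config import Quantize
from torchtitan.distributed.expert_parallel import ExpertParallel
from torchtitan.tools.logging import logger


def select_ep_class(quantize_config: Quantize | None = None):
    dispatch = getattr(quantize_config, "expert_parallel_a2a_dispatch_impl", "default")
    combine = getattr(quantize_config, "expert_parallel_a2a_combine_impl", "default")
    if dispatch == "mxfp8":
        # As in this PR, the combine hook follows the dispatch choice, since
        # MXExpertParallel replaces both; torchao availability is validated in
        # its __init__, so no extra checks are needed here.
        logger.info(f"Overriding EP a2a impls: dispatch={dispatch}, combine={combine}")
        return MXExpertParallel
    return ExpertParallel


# Inside apply_moe_ep_tp (with quantize_config as a keyword-only parameter):
# experts_plan = select_ep_class(quantize_config)()
```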
1 change: 1 addition & 0 deletions torchtitan/models/qwen3/infra/parallelize.py
@@ -95,6 +95,7 @@ def parallelize_qwen3(
if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
apply_moe_ep_tp(
model,
job_config,
tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
ep_tp_mesh=(