diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
index 05503cb5e..b86f2f81d 100644
--- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -14,13 +14,17 @@
 from torch._guards import tracing
 from torch.distributed.tensor import DTensor, Replicate
+
+from torch.fx.traceback import annotate_fn
 
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.expert_parallel import ExpertParallel
 from torchtitan.experiments.compiler_toolkit.graph_utils import export_joint
 from torchtitan.experiments.simple_fsdp.deepseek_v3.parallelize import (
     parallelize_deepseekv3 as simple_fsdp_parallelize_deepseekv3,
 )
+from torchtitan.models.moe.moe import MoE
 from torchtitan.tools.logging import logger
@@ -128,12 +132,25 @@ def wrapper_fn(args):
     return wrapper_fn
 
 
+def annotate_model() -> None:
+    # Annotate the MoE layers with dispatch, compute, and combine regions
+    ExpertParallel._token_dispatch = annotate_fn({"EP": "dispatch"})(
+        ExpertParallel._token_dispatch
+    )
+    ExpertParallel._token_combine = annotate_fn({"EP": "combine"})(
+        ExpertParallel._token_combine
+    )
+    MoE.forward = annotate_fn({"EP": "compute"})(MoE.forward)
+
+
 def parallelize_deepseekv3(
     model: nn.Module,
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ) -> CompiledModule:
+    annotate_model()
+
     # Disable torch.compile over the model in the compiler toolkit style workflow
     with disable_compile(job_config):
         model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config)
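
Note: for readers unfamiliar with the pattern in `annotate_model()`, the sketch below illustrates the class-level rewrapping it performs. The `annotate_fn` defined here is a hypothetical stand-in that only prints, so the snippet is self-contained and runnable; the real `torch.fx.traceback.annotate_fn` instead attaches the given dict to the metadata of FX nodes traced from the wrapped function, which is what lets the "EP" dispatch/compute/combine regions be identified in the exported joint graph.

```python
# Minimal sketch of the monkey-patching done by annotate_model() above.
# `annotate_fn` here is a hypothetical stand-in (it only prints); the real
# torch.fx.traceback.annotate_fn tags FX nodes traced from the wrapped
# function with the given metadata dict instead.
import functools
from typing import Any, Callable, Dict


def annotate_fn(meta: Dict[str, Any]) -> Callable:
    """Stand-in decorator factory mirroring the call shape used in the diff."""

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            print(f"{fn.__name__} annotated with {meta}")
            return fn(*args, **kwargs)

        return wrapper

    return decorator


class MoE:
    """Toy placeholder for torchtitan.models.moe.moe.MoE."""

    def forward(self, x: int) -> int:
        return 2 * x


# Rebinding the class attribute wraps every future call on every instance,
# which is why annotate_model() must run before the model is traced.
MoE.forward = annotate_fn({"EP": "compute"})(MoE.forward)

print(MoE().forward(3))  # prints the annotation, then 6
```

One consequence of patching at class scope: calling `annotate_model()` twice would wrap the methods twice. In the diff it runs exactly once, at the top of `parallelize_deepseekv3`, before the model is handed to the simple FSDP parallelizer and traced.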