From 3823537fc729e554fbdea7d60a1cd9f5c883f07b Mon Sep 17 00:00:00 2001
From: Sherlock Huang
Date: Mon, 27 Oct 2025 11:18:18 -0700
Subject: [PATCH 1/2] Add annotations to MoE

[ghstack-poisoned]
---
 .../compiler_toolkit/deepseek_v3/parallelize.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
index 05503cb5e..047daa2bb 100644
--- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -23,6 +23,10 @@
 )
 from torchtitan.tools.logging import logger
 
+from torch.fx.traceback import annotate_fn
+from torchtitan.models.moe.moe import MoE
+from torchtitan.distributed.expert_parallel import ExpertParallel
+
 
 @contextmanager
 def disable_compile(job_config: JobConfig):
@@ -128,12 +132,21 @@ def wrapper_fn(args):
     return wrapper_fn
 
 
+def annotate_model() -> None:
+    # annotate the MoE with dispatch, compute and combine
+    ExpertParallel._token_dispatch = annotate_fn({"EP": "dispatch"})(ExpertParallel._token_dispatch)
+    ExpertParallel._token_combine = annotate_fn({"EP": "combine"})(ExpertParallel._token_combine)
+    MoE.forward = annotate_fn({"EP": "compute"})(MoE.forward)
+
+
 def parallelize_deepseekv3(
     model: nn.Module,
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ) -> CompiledModule:
 
+    annotate_model()
+
     # Diable torch.compile over the model in the compiler toolkit style workflow
     with disable_compile(job_config):
         model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config)

From acb4235d65cd8b0e1f0fd43282f9012f4d7317f3 Mon Sep 17 00:00:00 2001
From: Sherlock Huang
Date: Mon, 27 Oct 2025 11:24:17 -0700
Subject: [PATCH 2/2] Update on "[Compiler Toolkit] Add annotations to MoE"

sample output

```
[rank0]: # Annotation: {'EP': 'dispatch'} File: /data/users/bahuang/pytorch/torch/distributed/_functional_collectives.py:485 in all_to_all_single, code: tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined]
[rank0]: tensor_3: "i64[8]" = torch.ops._c10d_functional.all_to_all_single(num_tokens_per_expert_3, [4, 4], [4, 4], '11')
[rank0]:
[rank0]: # Annotation: {'EP': 'dispatch'} File: /data/users/bahuang/pytorch/torch/distributed/_functional_collectives.py:136 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]: num_tokens_per_expert_group_2: "i64[8]" = torch.ops._c10d_functional.wait_tensor(tensor_3); tensor_3 = None
```

```
[rank0]: # Annotation: {'EP': 'combine'} File: /data/users/bahuang/pytorch/torch/distributed/_functional_collectives.py:522 in all_to_all_single_autograd, code: tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined]
[rank0]: slice_20: "bf16[u18 + u19, 256]" = torch.ops.aten.slice.Tensor(index_put_6, 0, 0, -1); index_put_6 = None
[rank0]: all_to_all_single_14: "bf16[u16 + u17, 256]" = torch.ops._c10d_functional.all_to_all_single.default(slice_20, [_local_scalar_dense_16, _local_scalar_dense_17], [_local_scalar_dense_18, _local_scalar_dense_19], '11'); slice_20 = None
[rank0]:
[rank0]: # Annotation: {'EP': 'combine'} File: /data/users/bahuang/pytorch/torch/distributed/_functional_collectives.py:528 in all_to_all_single_autograd, code: return _FromTorchTensor.apply(tensor)
[rank0]: wait_tensor_136: "bf16[u16 + u17, 256]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_14); all_to_all_single_14 = None
```

[ghstack-poisoned]
---
 .../compiler_toolkit/deepseek_v3/parallelize.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
index 047daa2bb..b86f2f81d 100644
--- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -14,18 +14,18 @@
 from torch._guards import tracing
 from torch.distributed.tensor import DTensor, Replicate
+
+from torch.fx.traceback import annotate_fn
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.expert_parallel import ExpertParallel
 from torchtitan.experiments.compiler_toolkit.graph_utils import export_joint
 from torchtitan.experiments.simple_fsdp.deepseek_v3.parallelize import (
     parallelize_deepseekv3 as simple_fsdp_parallelize_deepseekv3,
 )
-from torchtitan.tools.logging import logger
-
-from torch.fx.traceback import annotate_fn
 from torchtitan.models.moe.moe import MoE
-from torchtitan.distributed.expert_parallel import ExpertParallel
+from torchtitan.tools.logging import logger
 
 
 @contextmanager
@@ -134,8 +134,12 @@ def wrapper_fn(args):
 
 def annotate_model() -> None:
     # annotate the MoE with dispatch, compute and combine
-    ExpertParallel._token_dispatch = annotate_fn({"EP": "dispatch"})(ExpertParallel._token_dispatch)
-    ExpertParallel._token_combine = annotate_fn({"EP": "combine"})(ExpertParallel._token_combine)
+    ExpertParallel._token_dispatch = annotate_fn({"EP": "dispatch"})(
+        ExpertParallel._token_dispatch
+    )
+    ExpertParallel._token_combine = annotate_fn({"EP": "combine"})(
+        ExpertParallel._token_combine
+    )
     MoE.forward = annotate_fn({"EP": "compute"})(MoE.forward)
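
Note (not part of the patch): below is a minimal sketch of the same monkey-patch pattern that `annotate_model()` uses, assuming a PyTorch build where `torch.fx.traceback.annotate_fn` is available (as the diff itself assumes). `TinyExpert` and its annotation dict are hypothetical stand-ins for `MoE` / `ExpertParallel`, used only to illustrate the wrapping style.

```python
# Minimal sketch, not from the patch: wrap a class method with annotate_fn so
# that FX nodes traced while the method runs carry the given annotation dict,
# mirroring how annotate_model() wraps MoE.forward and the ExpertParallel
# dispatch/combine hooks. TinyExpert is a hypothetical toy module.
import torch
import torch.nn as nn
from torch.fx.traceback import annotate_fn  # assumes a recent PyTorch nightly


class TinyExpert(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


# Same style as the diff: decorate the unbound method on the class, so every
# instance picks up the annotation when the model is later traced/exported.
TinyExpert.forward = annotate_fn({"EP": "compute"})(TinyExpert.forward)
```

When a model patched this way is traced (e.g., via the compiler toolkit's `export_joint` path), the annotation is expected to surface on the traced nodes as `# Annotation: {'EP': ...}` comments, analogous to the sample output shown in the second commit message.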