
Commit cfae061

[Compiler Toolkit] Add annotations to MoE (#1937)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #1937
* #1906

sample output:

```
[rank0]: # Annotation: {'EP': 'dispatch'} File: /data/users/bahuang/pytorch/torch/distributed/_functional_collectives.py:485 in all_to_all_single, code: tensor = torch.ops._c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
[rank0]: tensor_3: "i64[8]" = torch.ops._c10d_functional.all_to_all_single(num_tokens_per_expert_3, [4, 4], [4, 4], '11')
[rank0]:
[rank0]: # Annotation: {'EP': 'dispatch'} File: /data/users/bahuang/pytorch/torch/distributed/_functional_collectives.py:136 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]
[rank0]: num_tokens_per_expert_group_2: "i64[8]" = torch.ops._c10d_functional.wait_tensor(tensor_3);  tensor_3 = None
```

```
[rank0]: # Annotation: {'EP': 'combine'} File: /data/users/bahuang/pytorch/torch/distributed/_functional_collectives.py:522 in all_to_all_single_autograd, code: tensor = torch.ops._c10d_functional_autograd.all_to_all_single(  # type: ignore[attr-defined]
[rank0]: slice_20: "bf16[u18 + u19, 256]" = torch.ops.aten.slice.Tensor(index_put_6, 0, 0, -1);  index_put_6 = None
[rank0]: all_to_all_single_14: "bf16[u16 + u17, 256]" = torch.ops._c10d_functional.all_to_all_single.default(slice_20, [_local_scalar_dense_16, _local_scalar_dense_17], [_local_scalar_dense_18, _local_scalar_dense_19], '11');  slice_20 = None
[rank0]:
[rank0]: # Annotation: {'EP': 'combine'} File: /data/users/bahuang/pytorch/torch/distributed/_functional_collectives.py:528 in all_to_all_single_autograd, code: return _FromTorchTensor.apply(tensor)
[rank0]: wait_tensor_136: "bf16[u16 + u17, 256]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_14);  all_to_all_single_14 = None
```
1 parent 1b35808 commit cfae061
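For context, here is a minimal, self-contained sketch (not part of this commit) of the `torch.fx.traceback.annotate_fn` usage pattern this change relies on: wrapping a callable so that ops traced from it carry a custom annotation dict, which then surfaces as `# Annotation: {...}` comments in readable graph dumps like the sample output above. The toy module, tensor shapes, and the export/inspection step are illustrative assumptions, and `annotate_fn` requires a PyTorch build that provides it.

```python
# Illustrative sketch only -- not from this commit. Assumes a PyTorch build
# that ships torch.fx.traceback.annotate_fn (the API used in the diff below).
import torch
from torch.fx.traceback import annotate_fn


# Wrap a plain function so ops traced from inside it carry {"EP": "compute"};
# the function body and shapes are made up for illustration.
@annotate_fn({"EP": "compute"})
def toy_compute(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return toy_compute(x)


if __name__ == "__main__":
    # Export and dump the traced graph; nodes originating from toy_compute are
    # expected to carry the annotation, shown in readable dumps in a form
    # similar to the "# Annotation: {'EP': 'compute'}" lines above.
    ep = torch.export.export(Toy(), (torch.randn(4),))
    print(ep.graph_module.print_readable(print_output=False))
```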

File tree

1 file changed: +17 −0 lines


torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -14,13 +14,17 @@
 from torch._guards import tracing
 
 from torch.distributed.tensor import DTensor, Replicate
+
+from torch.fx.traceback import annotate_fn
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.expert_parallel import ExpertParallel
 
 from torchtitan.experiments.compiler_toolkit.graph_utils import export_joint
 from torchtitan.experiments.simple_fsdp.deepseek_v3.parallelize import (
     parallelize_deepseekv3 as simple_fsdp_parallelize_deepseekv3,
 )
+from torchtitan.models.moe.moe import MoE
 from torchtitan.tools.logging import logger
 
 
@@ -128,12 +132,25 @@ def wrapper_fn(args):
     return wrapper_fn
 
 
+def annotate_model() -> None:
+    # annotate the MoE with dispatch, compute and combine
+    ExpertParallel._token_dispatch = annotate_fn({"EP": "dispatch"})(
+        ExpertParallel._token_dispatch
+    )
+    ExpertParallel._token_combine = annotate_fn({"EP": "combine"})(
+        ExpertParallel._token_combine
+    )
+    MoE.forward = annotate_fn({"EP": "compute"})(MoE.forward)
+
+
 def parallelize_deepseekv3(
     model: nn.Module,
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ) -> CompiledModule:
 
+    annotate_model()
+
     # Diable torch.compile over the model in the compiler toolkit style workflow
     with disable_compile(job_config):
         model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config)
```
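The `annotate_model()` helper works by rebinding class attributes to decorated versions of themselves, so every later call to `ExpertParallel._token_dispatch`, `ExpertParallel._token_combine`, or `MoE.forward` goes through `annotate_fn` without touching the model or expert-parallel implementations. Below is a self-contained sketch of that rebinding pattern, using a stand-in decorator and hypothetical class/method names so it runs without torchtitan or a tracing-capable PyTorch.

```python
# Illustrative sketch of the monkey-patch pattern used by annotate_model();
# `tag` stands in for annotate_fn, and TokenRouter is a hypothetical class.
from functools import wraps


def tag(metadata: dict):
    def decorator(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            # a real annotate_fn attaches `metadata` to traced FX nodes;
            # here we just record it so the wrapping is observable
            inner.last_metadata = metadata
            return fn(*args, **kwargs)
        return inner
    return decorator


class TokenRouter:
    def _token_dispatch(self, x):
        return x


# Rebind the class attribute to its decorated version, exactly the way
# annotate_model() patches ExpertParallel._token_dispatch / _token_combine
# and MoE.forward.
TokenRouter._token_dispatch = tag({"EP": "dispatch"})(TokenRouter._token_dispatch)

router = TokenRouter()
router._token_dispatch([1, 2, 3])
print(TokenRouter._token_dispatch.last_metadata)  # {'EP': 'dispatch'}
```

Patching the class attributes keeps the annotation logic local to the compiler-toolkit parallelize path instead of adding tracing concerns to the MoE or ExpertParallel code itself.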

0 commit comments
