@@ -14,13 +14,17 @@
 from torch._guards import tracing

 from torch.distributed.tensor import DTensor, Replicate
+
+from torch.fx.traceback import annotate_fn
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.expert_parallel import ExpertParallel

 from torchtitan.experiments.compiler_toolkit.graph_utils import export_joint
 from torchtitan.experiments.simple_fsdp.deepseek_v3.parallelize import (
     parallelize_deepseekv3 as simple_fsdp_parallelize_deepseekv3,
 )
+from torchtitan.models.moe.moe import MoE
 from torchtitan.tools.logging import logger

@@ -128,12 +132,25 @@ def wrapper_fn(args): |
     return wrapper_fn


+def annotate_model() -> None:
+    # annotate MoE token dispatch, expert compute, and token combine
+    ExpertParallel._token_dispatch = annotate_fn({"EP": "dispatch"})(
+        ExpertParallel._token_dispatch
+    )
+    ExpertParallel._token_combine = annotate_fn({"EP": "combine"})(
+        ExpertParallel._token_combine
+    )
+    MoE.forward = annotate_fn({"EP": "compute"})(MoE.forward)
+
+
 def parallelize_deepseekv3(
     model: nn.Module,
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ) -> CompiledModule:

+    annotate_model()
+
     # Disable torch.compile over the model in the compiler toolkit style workflow
     with disable_compile(job_config):
         model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config)
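
The annotate_fn pattern above is a plain decorator factory: it takes a metadata dict and returns a decorator, which is why annotate_model can re-bind existing methods to annotated wrappers. Below is a minimal, self-contained sketch of how one might verify such annotations on a traced graph. It is not part of the commit: token_dispatch is a hypothetical stand-in for ExpertParallel._token_dispatch, and the assumptions that annotations are recorded only while node meta is preserved and surface under node.meta["custom"] should be checked against your torch version.

import torch
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.traceback import annotate_fn


# Hypothetical stand-in for ExpertParallel._token_dispatch; only the
# annotation pattern matches the commit above.
@annotate_fn({"EP": "dispatch"})
def token_dispatch(x: torch.Tensor) -> torch.Tensor:
    return x + 1


def fn(x: torch.Tensor) -> torch.Tensor:
    return token_dispatch(x) * 2


# Assumption: annotations are only recorded while node meta is preserved,
# and they surface under node.meta["custom"]; verify on your torch build.
with fx_traceback.preserve_node_meta():
    gm = make_fx(fn)(torch.randn(4))

for node in gm.graph.nodes:
    print(node.name, node.meta.get("custom"))

If the assumption holds, the nodes traced from token_dispatch print {"EP": "dispatch"}, which is how the {"EP": ...} tags applied in annotate_model would mark the dispatch, compute, and combine regions in the exported joint graph.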