# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib

import torch
import torch.nn as nn

from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors
from torch._guards import tracing, TracingContext

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp

from torchtitan.experiments.simple_fsdp.deepseek_v3.model import (
    SimpleFSDPDeepSeekV3Model,
)
from torchtitan.experiments.simple_fsdp.simple_fsdp import (
    data_parallel,
    MixedPrecisionPolicy,
)

from torchtitan.models.deepseek_v3.infra.parallelize import (
    apply_ac,
    apply_moe_ep_tp,
    apply_non_moe_tp,
)
from torchtitan.tools.logging import logger

from torchtitan.experiments.compiler_toolkit.graph_utils import (
    export_joint,
    print_if_rank0,
)


# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
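    """
    Apply tensor/expert parallelism, activation checkpointing, and data
    parallelism (simple_fsdp) to the DeepSeek-V3 model, then optionally wrap
    it in CompiledModule so the joint forward/backward graph is compiled
    lazily on the first forward call.
    """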
    world_mesh = parallel_dims.world_mesh
    # TODO: TP currently cannot handle uneven seq_len because we set
    # `use_local_output=True` to use plain Tensors for legacy reasons.
    # Need to revisit this.
    assert (
        job_config.training.seq_len % parallel_dims.seq_len_divisor == 0
    ), f"""
    Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
    ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}), i.e. {parallel_dims.seq_len_divisor}.
    """
    if (
        job_config.parallelism.context_parallel_degree > 1
        and model.model_args.use_flex_attn
    ):
        raise NotImplementedError("CP support for FlexAttention is still in progress.")

    if parallel_dims.tp_enabled:
        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
        if enable_float8_tensorwise_tp:
            # TODO(jianiw): This branch needs to be tested and enabled
            raise NotImplementedError(
                "Currently, float8 tensorwise TP is not tested for deepseekv3"
            )

        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
        apply_non_moe_tp(
            model,
            world_mesh["tp"],
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=False,
            use_flex_attn=use_flex_attn,
        )
        maybe_enable_async_tp(job_config, world_mesh["tp"])

    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
        apply_moe_ep_tp(
            model,
            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
            ep_tp_mesh=(
                world_mesh["ep", "tp"]
                if parallel_dims.tp_enabled
                and parallel_dims.ep_enabled
                and parallel_dims.etp_enabled
                else None
            ),
            etp_enabled=parallel_dims.etp_enabled,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    mp_policy = MixedPrecisionPolicy(
        param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
    )

    # apply data parallel
    dp_mesh: DeviceMesh | None = None
    if (
        parallel_dims.fsdp_enabled
        or parallel_dims.ep_enabled
        or parallel_dims.dp_replicate_enabled
    ):
        if parallel_dims.dp_replicate_enabled:
            if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
                dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
                dp_mode = "hybrid_shard"
            else:
                dp_mesh_dim_names = ("dp_replicate",)
                dp_mode = "replicate"
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)
            dp_mode = "fully_shard"

        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
        dp_mod_ep_mesh_dim_names = []
        if parallel_dims.ep_enabled:
            if parallel_dims.dp_replicate_enabled:
                dp_mod_ep_mesh_dim_names.append("dp_replicate")
            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
            dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)]

        for _, transformer_block in model.layers.items():
            if transformer_block.moe_enabled and parallel_dims.ep_enabled:
                experts_shard_dim = 0
                assert dp_mod_ep_mesh is not None
                assert hasattr(transformer_block, "moe")
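                # When the combined EP x FSDP sharding degree exceeds the
                # number of experts, dim 0 (the expert dim) cannot host one
                # shard per rank, so shard the expert weights along dim 1.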
                if (
                    dp_mod_ep_mesh.size() * parallel_dims.ep
                    > transformer_block.moe.experts.num_experts
                ):
                    experts_shard_dim = 1

                transformer_block.moe.experts = data_parallel(
                    transformer_block.moe.experts,
                    dp_mod_ep_mesh,
                    dp_mode,
                    ac_mode=job_config.activation_checkpoint.mode,
                    mp_policy=mp_policy,
                    shard_dim=experts_shard_dim,
                )
                # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
                # transformer_block.moe.experts.set_gradient_divide_factor(
                #     parallel_dims.fsdp_gradient_divide_factor,
                # )

        model = data_parallel(
            model,
            dp_mesh,
            dp_mode,
            ac_mode=job_config.activation_checkpoint.mode,
            mp_policy=mp_policy,
        )
        logger.info(
            "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
        )

    if job_config.compile.enable:
        # TODO: CompiledModule should take sample input as well, so that we can
        # compile ahead of time.
        model = CompiledModule(model, parallel_dims)

    return model


class CompiledModule(torch.nn.Module):
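    """
    nn.Module wrapper that lazily exports and compiles the joint
    forward/backward graph of `inner` on the first forward call, while
    delegating attribute access to the wrapped module so it can be used as a
    drop-in replacement for it.
    """
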
    def __init__(
        self, inner: torch.nn.Module, parallel_dims: ParallelDims, **overrides
    ):
        super().__init__()
        self.inner = inner  # register as submodule
        self.parallel_dims = parallel_dims

        self.joint_graph_module = None
        self._overrides = overrides  # for custom hooks

    def __getattr__(self, name):
        # check overrides first
        if "_overrides" in self.__dict__ and name in self._overrides:
            return self._overrides[name]
        try:
            # let nn.Module handle registered submodules/parameters/buffers
            return super().__getattr__(name)
        except AttributeError:
            # fall back to the inner model
            return getattr(self.inner, name)

    def __setattr__(self, name, value):
        if "_overrides" in self.__dict__ and name in self._overrides:
            self._overrides[name] = value
        else:
            super().__setattr__(name, value)

    def __delattr__(self, name):
        if "_overrides" in self.__dict__ and name in self._overrides:
            del self._overrides[name]
        else:
            super().__delattr__(name)

    def forward(self, *args, **kwargs):
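        """
        Wrap the positional inputs as replicated DTensors on the TP mesh,
        build the compiled joint graph on the first call, and then run it.
        """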
| 206 | + assert "forward" not in self._overrides, "forward cannot be overridden" |
| 207 | + dt_args = tuple( |
| 208 | + DTensor.from_local(arg, self.parallel_dims.world_mesh["tp"], [Replicate()]) |
| 209 | + for arg in args |
| 210 | + ) |
| 211 | + if self.joint_graph_module is None: |
| 212 | + self.joint_graph_module = joint_graph_builder( |
| 213 | + self.inner, *dt_args, **kwargs |
| 214 | + ) |
| 215 | + |
        # Calling the compiled joint graph returns control to torchtitan's
        # runner, which then invokes backward and the optimizer step.

        # TODO: add support for kwargs
        return self.joint_graph_module(dt_args)


def joint_graph_builder(model, *inputs, **kwargs):
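    """
    Export the joint forward/backward graph of `model` for the given DTensor
    inputs, compile it with AOTAutograd, and return a callable that takes only
    the runtime inputs (parameters and buffers are captured from `model`).
    """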
    assert isinstance(inputs, tuple)
    for inp in inputs:
        assert isinstance(inp, DTensor)

    # get joint graph
    (
        joint_with_descriptors,
        tracing_context,
    ) = export_joint(model, inputs)

    def fw_compiler(gm: torch.fx.GraphModule, example_inputs):
        print_if_rank0("fwd_gm:")
        print_if_rank0(gm.print_readable(print_output=False))

        # print_if_rank0("After compiler:")
        # print_if_rank0(gm.print_readable(print_output=False))
        return gm

    def bw_compiler(gm: torch.fx.GraphModule, example_inputs):
        print_if_rank0("bwd_gm:")
        print_if_rank0(gm.print_readable(print_output=False))

        # print_if_rank0("After compiler:")
        # print_if_rank0(gm.print_readable(print_output=False))
        return gm

    with tracing(tracing_context):
        fn = aot_compile_joint_with_descriptors(
            joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler
        )

    def wrapper_fn(args):
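        # Assemble the flat input list the compiled function expects:
        # parameters first, then buffers, then the runtime args.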
        flat_args = [
            *model.parameters(),
            *model.buffers(),
            *args,
        ]
        return fn(*flat_args)

    return wrapper_fn
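

# A rough usage sketch (names and shapes are illustrative; in practice
# torchtitan's training loop drives this via the registered TrainSpec):
#
#   model = parallelize_deepseekv3(model, parallel_dims, job_config)
#   logits = model(tokens)              # first call traces + compiles the joint graph
#   loss_fn(logits, labels).backward()  # backward/optimizer are run by the trainer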