|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import time |
| 8 | + |
| 9 | +import torch |
| 10 | +from autoparallel.api import AutoParallel |
| 11 | +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing |
| 12 | + |
| 13 | +from torch.distributed.tensor.placement_types import Shard |
| 14 | +from torchtitan.config import JobConfig |
| 15 | +from torchtitan.distributed import ParallelDims |
| 16 | + |
| 17 | +from torchtitan.tools.logging import logger |
| 18 | + |
| 19 | + |
| 20 | +# TODO: Autoparallel should transparently wrap the original nn.Module |
| 21 | +# but I don't know how to do that. |
| 22 | +def set_torchtitan_fields(orig, new): |
| 23 | + assert isinstance(new.layers, torch.nn.ModuleDict) |
| 24 | + for block in new.layers.values(): |
| 25 | + block.moe_enabled = hasattr(block, "moe") |
| 26 | + |
| 27 | + |
| 28 | +def parallelize_deepseekv3( |
| 29 | + model, |
| 30 | + parallel_dims: ParallelDims, |
| 31 | + job_config: JobConfig, |
| 32 | +): |
| 33 | + """ |
| 34 | + Apply Autoparallel to the model |
| 35 | +
|
| 36 | + NOTE: The passed-in model preferably should be on meta device. Otherwise, |
| 37 | + the model must fit on GPU or CPU memory. |
| 38 | + """ |
| 39 | + |
| 40 | + # TODO(whc) |
| 41 | + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering |
| 42 | + torch._inductor.config.force_disable_caches = True |
| 43 | + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. |
| 44 | + torch._inductor.config.allow_buffer_reuse = False |
| 45 | + |
| 46 | + # allow configuring inductor comms optimizations from torchtitan commandline |
| 47 | + configure_inductor_for_autobucketing( |
| 48 | + job_config.experimental.comms_bucket_reorder_strategy |
| 49 | + ) |
| 50 | + |
| 51 | + world_mesh = parallel_dims.world_mesh |
| 52 | + |
| 53 | + # Update me when changing dsv3.py |
| 54 | + assert world_mesh.ndim == 2, "AP dsv3.py's local_map is specialized on 2 dims" |
| 55 | + |
| 56 | + # Provide AP MoE with mesh |
| 57 | + for layer in model.layers.values(): |
| 58 | + if layer.moe_enabled: |
| 59 | + layer.moe.mesh = world_mesh |
| 60 | + layer.moe.axis_name = "dp_shard_in_ep" |
| 61 | + |
| 62 | + def input_fn(): |
| 63 | + global_batch_size = job_config.training.global_batch_size |
| 64 | + if global_batch_size < 0: |
| 65 | + # This global batch size results in 1 gradient accumulation |
| 66 | + # step. |
| 67 | + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard |
| 68 | + global_batch_size = job_config.training.local_batch_size * dp_degree |
| 69 | + return ( |
| 70 | + torch.randint( |
| 71 | + 0, |
| 72 | + model.model_args.vocab_size, |
| 73 | + (global_batch_size, job_config.training.seq_len), |
| 74 | + device=torch.device("cuda"), |
| 75 | + ), |
| 76 | + ) |
| 77 | + |
| 78 | + should_compile = job_config.compile.enable |
| 79 | + if should_compile: |
| 80 | + # TODO: support more options in AP API |
| 81 | + assert job_config.compile.components == ["model"] |
| 82 | + assert job_config.compile.backend == "inductor" |
| 83 | + |
| 84 | + mp_policy = None |
| 85 | + with AutoParallel( |
| 86 | + model, |
| 87 | + input_fn, |
| 88 | + world_mesh, |
| 89 | + mp_policy=mp_policy, |
| 90 | + compile=should_compile, |
| 91 | + dynamic=True, |
| 92 | + ) as autop: |
| 93 | + autop.add_parameter_memory_constraint(low=None, high=None) |
| 94 | + |
| 95 | + x_sharding = (Shard(0), Shard(0)) |
| 96 | + loss_parallel_enabled = ( |
| 97 | + parallel_dims.tp_enabled |
| 98 | + and not job_config.parallelism.disable_loss_parallel |
| 99 | + ) |
| 100 | + assert not loss_parallel_enabled |
| 101 | + autop.add_input_constraints([x_sharding]) |
| 102 | + autop.add_output_constraints([x_sharding]) |
| 103 | + t0 = time.time() |
| 104 | + sharding_placement = autop.optimize_placement() |
| 105 | + t1 = time.time() |
| 106 | + logger.info(f"AutoParallel took {t1 - t0} seconds") |
| 107 | + parallel_mod = autop.apply_placement(sharding_placement) |
| 108 | + |
| 109 | + set_torchtitan_fields(model, parallel_mod) |
| 110 | + |
| 111 | + if loss_parallel_enabled: |
| 112 | + |
| 113 | + # current PyTorch's implementation of loss parallel assumes |
| 114 | + # that the DTensor has a 1d device mesh. This is not true |
| 115 | + # in our case, but we can work around it by adding |
| 116 | + # casting the output to a DTensor on a 1d device mesh. |
| 117 | + # We should just use AutoParallel to do this for us, but |
| 118 | + # it would require putting the loss inside the model as well |
| 119 | + def _return_as_dtensor_for_loss_parallel(module, args, output): |
| 120 | + return torch.distributed.tensor.DTensor.from_local( |
| 121 | + output, world_mesh["tp"], (Shard(2),) |
| 122 | + ) |
| 123 | + |
| 124 | + # not keeping a reference to the hook, don't plan on |
| 125 | + # removing it at any point |
| 126 | + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) |
| 127 | + |
| 128 | + _preserve_moe_attributes(model, parallel_mod) |
| 129 | + |
| 130 | + return parallel_mod |
| 131 | + |
| 132 | + |
| 133 | +def _preserve_moe_attributes(original_model, parallel_model): |
| 134 | + """ |
| 135 | + Preserve MoE custom attributes from the original model to the parallel model. |
| 136 | + This is only needed for attributes that aren't used in the graph, so they aren't |
| 137 | + lifted as graph inputs and fetched by the pre-graph runtime wrapper. |
| 138 | +
|
| 139 | + `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify |
| 140 | + this block as a moe block. This should be safe as they are read-only. |
| 141 | + """ |
| 142 | + |
| 143 | + def get_moe_modules(model): |
| 144 | + """Extract all MoE modules from the model.""" |
| 145 | + moe_modules = [] |
| 146 | + if hasattr(model, "layers"): |
| 147 | + if isinstance(model.layers, torch.nn.ModuleDict): |
| 148 | + # regular torchtitan structure |
| 149 | + blocks = model.layers.values() |
| 150 | + else: |
| 151 | + # autoparallel might change structure |
| 152 | + blocks = ( |
| 153 | + model.layers.children() if hasattr(model.layers, "children") else [] |
| 154 | + ) |
| 155 | + |
| 156 | + for block in blocks: |
| 157 | + if ( |
| 158 | + hasattr(block, "moe_enabled") |
| 159 | + and block.moe_enabled |
| 160 | + and hasattr(block, "moe") |
| 161 | + ): |
| 162 | + moe_modules.append(block.moe) |
| 163 | + elif hasattr(block, "moe"): # fallback for autoparallel |
| 164 | + moe_modules.append(block.moe) |
| 165 | + return moe_modules |
| 166 | + |
| 167 | + original_moe_modules = get_moe_modules(original_model) |
| 168 | + parallel_moe_modules = get_moe_modules(parallel_model) |
| 169 | + |
| 170 | + # Copy custom attributes from original to parallel MoE modules |
| 171 | + # This is fine to do since these attributes are read only |
| 172 | + for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules): |
| 173 | + if hasattr(orig_moe, "moe_enabled"): |
| 174 | + par_moe.load_balance_coeff = orig_moe.load_balance_coeff |
| 175 | + |
| 176 | + # Copy load_balance_coeff |
| 177 | + if hasattr(orig_moe, "load_balance_coeff"): |
| 178 | + par_moe.load_balance_coeff = orig_moe.load_balance_coeff |
0 commit comments