From 573624016144705905ce6d88d5362e31ecd0930b Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Mon, 8 Dec 2025 21:07:57 -0800 Subject: [PATCH] [Autoparallel] Add local_map variant of DSv3 and 2D mesh AP stack-info: PR: https://github.com/pytorch/torchtitan/pull/2129, branch: xmfan/stack/7 --- torchtitan/experiments/__init__.py | 1 + torchtitan/experiments/autoparallel/README.md | 6 + .../deepseek_v3/parallelize_deepseekv3.py | 2 - .../local_map_deepseek_v3/__init__.py | 57 ++++++ .../local_map_deepseek_v3/args.py | 49 +++++ .../local_map_deepseek_v3/model.py | 18 ++ .../parallelize_deepseekv3.py | 182 ++++++++++++++++++ 7 files changed, 313 insertions(+), 2 deletions(-) create mode 100644 torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py create mode 100644 torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py create mode 100644 torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py create mode 100644 torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 7d7f4da41a..10f9030c1d 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -15,5 +15,6 @@ "transformers_modeling_backend", "autoparallel.llama3", "autoparallel.deepseek_v3", + "autoparallel.local_map_deepseek_v3", ] ) diff --git a/torchtitan/experiments/autoparallel/README.md b/torchtitan/experiments/autoparallel/README.md index 3be86b9bc3..570237b4d9 100644 --- a/torchtitan/experiments/autoparallel/README.md +++ b/torchtitan/experiments/autoparallel/README.md @@ -17,3 +17,9 @@ Requires installing [git@github.com:meta-pytorch/autoparallel.git](https://githu **DeepSeekv3** `CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config` + +**DeepSeekv3 local_map** + +This is a variant of titan's DSv3, which uses a local_map for the expert parallel region. This only supports 2D mesh right now. NOTE: the mesh provided are just to reuse torchtitan's trainer mesh setup code. Autoparallel is not bound to use dp2ep. + +`NGPU=2 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml tlp ./run_train.sh --model.name autoparallel.local_map_deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config --parallelism.data_parallel_shard_degree 2 --parallelism.expert_parallel_degree 2` diff --git a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py index 0f718a389b..80dfcac9a3 100644 --- a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py @@ -257,8 +257,6 @@ def set_torchtitan_fields(orig, new): block.moe_enabled = hasattr(block, "moe") -# Run workflow with: -# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_autoparallel def parallelize_deepseekv3( model, parallel_dims: ParallelDims, diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py new file mode 100644 index 0000000000..fdd8435ebc --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +import copy + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader + +from torchtitan.models.deepseek_v3 import deepseekv3_args +from torchtitan.models.deepseek_v3.model.state_dict_adapter import ( + DeepSeekV3StateDictAdapter, +) +from torchtitan.protocols.train_spec import TrainSpec + +from .args import DeepSeekV3ModelArgs, get_sample_config + +from .model import DeepSeekV3Model +from .parallelize_deepseekv3 import parallelize_deepseekv3 + + +def get_model_args() -> DeepSeekV3ModelArgs: + model_args = copy.deepcopy(deepseekv3_args) + # TODO: Align configs between AP and Titan + for config in model_args.keys(): + # Just override the configs + override = get_sample_config() + override.update_from_config = model_args[config].update_from_config + override.get_nparams_and_flops = model_args[config].get_nparams_and_flops + model_args[config] = override + + return model_args + + +def get_train_spec() -> TrainSpec: + model_args = get_model_args() + + return TrainSpec( + model_cls=DeepSeekV3Model, + model_args=model_args, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, + ) diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py new file mode 100644 index 0000000000..7f1f84f45a --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from dataclasses import dataclass + +from autoparallel._testing.models.dsv3 import ( + DeepSeekV3ModelArgs as _DeepSeekV3ModelArgs, + MoEArgs as _MoEArgs, +) +from torchtitan.protocols.model import BaseModelArgs + + +# Need to share same base class with torchtitan models +@dataclass +class DeepSeekV3ModelArgs(_DeepSeekV3ModelArgs, BaseModelArgs): + pass + + +def get_sample_config() -> DeepSeekV3ModelArgs: + return DeepSeekV3ModelArgs( + vocab_size=2048, + max_seq_len=2048, + dim=256, + inter_dim=1024, + moe_inter_dim=256, + n_layers=4, + n_dense_layers=0, + n_heads=16, + moe_args=_MoEArgs( + num_experts=4, + num_shared_experts=2, + top_k=2, + score_func="softmax", + route_norm=False, + score_before_experts=False, + mesh=None, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ) diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py new file mode 100644 index 0000000000..f4915fb708 --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from autoparallel._testing.models.dsv3 import DeepSeekV3Model as _DeepSeekV3Model +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import DeepSeekV3ModelArgs + + +# Need to share same base class with torchtitan models +class DeepSeekV3Model(_DeepSeekV3Model, ModelProtocol): + def __init__(self, model_args: DeepSeekV3ModelArgs): + super().__init__(model_args) diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py new file mode 100644 index 0000000000..eb400484f6 --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch +from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing + +from torch.distributed.tensor.placement_types import Shard +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +# TODO: Autoparallel should transparently wrap the original nn.Module +# but I don't know how to do that. +def set_torchtitan_fields(orig, new): + assert isinstance(new.layers, torch.nn.ModuleDict) + for block in new.layers.values(): + block.moe_enabled = hasattr(block, "moe") + + +def parallelize_deepseekv3( + model, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply Autoparallel to the model + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + + world_mesh = parallel_dims.world_mesh + + # Update me when changing dsv3.py + assert world_mesh.ndim == 2, "AP dsv3.py's local_map is specialized on 2 dims" + assert world_mesh.mesh_dim_names == ( + "dp_shard_mod_ep", + "dp_shard_in_ep", + ), "Current setup assumes these specific meshes" + + # Provide AP MoE with mesh + for layer in model.layers.values(): + if layer.moe_enabled: + layer.moe.mesh = world_mesh + layer.moe.axis_name = "dp_shard_in_ep" + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + model.model_args.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + should_compile = job_config.compile.enable + if should_compile: + # TODO: support more options in AP API + assert job_config.compile.components == ["model"] + assert job_config.compile.backend == "inductor" + + mp_policy = None + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=should_compile, + dynamic=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = (Shard(0), Shard(0)) + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + assert not loss_parallel_enabled + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + set_torchtitan_fields(model, parallel_mod) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + _preserve_moe_attributes(model, parallel_mod) + + return parallel_mod + + +def _preserve_moe_attributes(original_model, parallel_model): + """ + Preserve MoE custom attributes from the original model to the parallel model. + This is only needed for attributes that aren't used in the graph, so they aren't + lifted as graph inputs and fetched by the pre-graph runtime wrapper. + + `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify + this block as a moe block. This should be safe as they are read-only. + """ + + def get_moe_modules(model): + """Extract all MoE modules from the model.""" + moe_modules = [] + if hasattr(model, "layers"): + if isinstance(model.layers, torch.nn.ModuleDict): + # regular torchtitan structure + blocks = model.layers.values() + else: + # autoparallel might change structure + blocks = ( + model.layers.children() if hasattr(model.layers, "children") else [] + ) + + for block in blocks: + if ( + hasattr(block, "moe_enabled") + and block.moe_enabled + and hasattr(block, "moe") + ): + moe_modules.append(block.moe) + elif hasattr(block, "moe"): # fallback for autoparallel + moe_modules.append(block.moe) + return moe_modules + + original_moe_modules = get_moe_modules(original_model) + parallel_moe_modules = get_moe_modules(parallel_model) + + # Copy custom attributes from original to parallel MoE modules + # This is fine to do since these attributes are read only + for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules): + if hasattr(orig_moe, "moe_enabled"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff + + # Copy load_balance_coeff + if hasattr(orig_moe, "load_balance_coeff"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff