[Autoparallel] Add local_map variant of DSv3 and 2D mesh AP #2129
@@ -17,3 +17,9 @@ Requires installing [git@github.com:meta-pytorch/autoparallel.git](https://github.com/meta-pytorch/autoparallel)
**DeepSeekv3**

`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config`

**DeepSeekv3 local_map**

This is a variant of titan's DSv3 that uses a local_map for the expert-parallel region. It only supports a 2D mesh right now. NOTE: the provided mesh is just there to reuse torchtitan's trainer mesh setup code; Autoparallel is not bound to use dp2ep.

`NGPU=2 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml tlp ./run_train.sh --model.name autoparallel.local_map_deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config --parallelism.data_parallel_shard_degree 2 --parallelism.expert_parallel_degree 2`
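As background for the description above, here is a minimal hand-rolled sketch of the local_map idea (the actual variant uses `local_map` inside autoparallel's dsv3.py; the helper name, shapes, and placements below are illustrative assumptions, not code from this PR):

```python
# Hand-rolled illustration of the local_map idea: the expert-parallel region runs
# plain-tensor code on this rank's local shard, while the surrounding model keeps
# working with DTensors on the mesh.
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Shard


def run_expert_region_locally(x: DTensor, experts_fn, mesh) -> DTensor:
    local_x = x.to_local()            # drop to this rank's shard of the tokens
    local_out = experts_fn(local_x)   # expert compute stays in plain torch ops
    # Re-wrap so the rest of the parallelized graph sees a DTensor again.
    return DTensor.from_local(local_out, mesh, (Shard(0),))
```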
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import copy

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

from torchtitan.models.deepseek_v3 import deepseekv3_args
from torchtitan.models.deepseek_v3.model.state_dict_adapter import (
    DeepSeekV3StateDictAdapter,
)
from torchtitan.protocols.train_spec import TrainSpec

from .args import DeepSeekV3ModelArgs, get_sample_config

from .model import DeepSeekV3Model
from .parallelize_deepseekv3 import parallelize_deepseekv3


def get_model_args() -> DeepSeekV3ModelArgs:
    model_args = copy.deepcopy(deepseekv3_args)
    # TODO: Align configs between AP and Titan
    for config in model_args.keys():
        # Just override the configs
        override = get_sample_config()
        override.update_from_config = model_args[config].update_from_config
        override.get_nparams_and_flops = model_args[config].get_nparams_and_flops
        model_args[config] = override

    return model_args


def get_train_spec() -> TrainSpec:
    model_args = get_model_args()

    return TrainSpec(
        model_cls=DeepSeekV3Model,
        model_args=model_args,
        parallelize_fn=parallelize_deepseekv3,
        pipelining_fn=pipeline_llm,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_text_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        state_dict_adapter=DeepSeekV3StateDictAdapter,
    )
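As an aside, the override loop in `get_model_args` grafts torchtitan's helper methods onto the AP sample config. The self-contained sketch below illustrates that pattern with stand-in classes (the names, fields, and numbers are hypothetical, not torchtitan's):

```python
from dataclasses import dataclass


@dataclass
class TitanArgs:  # hypothetical stand-in for a torchtitan DSv3 args entry
    dim: int = 1024

    def get_nparams_and_flops(self, seq_len: int):
        # placeholder math, only here to have a method to carry over
        return self.dim * seq_len, 2.0 * self.dim * seq_len


@dataclass
class APArgs:  # hypothetical stand-in for the AP sample config
    dim: int = 256


registry = {"debugmodel": TitanArgs()}
for name in registry.keys():
    override = APArgs()
    # Assigning a bound method keeps `self` bound to the original TitanArgs
    # instance; get_model_args() relies on the same Python semantics.
    override.get_nparams_and_flops = registry[name].get_nparams_and_flops
    registry[name] = override

print(registry["debugmodel"].dim)  # 256: the AP config's fields win
```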
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from dataclasses import dataclass

from autoparallel._testing.models.dsv3 import (
    DeepSeekV3ModelArgs as _DeepSeekV3ModelArgs,
    MoEArgs as _MoEArgs,
)
from torchtitan.protocols.model import BaseModelArgs


# Need to share same base class with torchtitan models
@dataclass
class DeepSeekV3ModelArgs(_DeepSeekV3ModelArgs, BaseModelArgs):
    pass


def get_sample_config() -> DeepSeekV3ModelArgs:
    return DeepSeekV3ModelArgs(
        vocab_size=2048,
        max_seq_len=2048,
        dim=256,
        inter_dim=1024,
        moe_inter_dim=256,
        n_layers=4,
        n_dense_layers=0,
        n_heads=16,
        moe_args=_MoEArgs(
            num_experts=4,
            num_shared_experts=2,
            top_k=2,
            score_func="softmax",
            route_norm=False,
            score_before_experts=False,
            mesh=None,
        ),
        q_lora_rank=0,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        mscale=0.70,
    )
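A brief usage sketch of the args above (assumed to run after importing `get_sample_config` from this module; the printed values follow directly from the sample config):

```python
# The mixin above makes AP's args acceptable anywhere torchtitan expects a
# BaseModelArgs; the values below come straight from get_sample_config().
from torchtitan.protocols.model import BaseModelArgs

cfg = get_sample_config()
assert isinstance(cfg, BaseModelArgs)
print(cfg.dim, cfg.n_layers, cfg.moe_args.num_experts)  # 256 4 4
```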
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from autoparallel._testing.models.dsv3 import DeepSeekV3Model as _DeepSeekV3Model
from torchtitan.protocols.train_spec import ModelProtocol

from .args import DeepSeekV3ModelArgs


# Need to share same base class with torchtitan models
class DeepSeekV3Model(_DeepSeekV3Model, ModelProtocol):
    def __init__(self, model_args: DeepSeekV3ModelArgs):
        super().__init__(model_args)
@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch
from autoparallel.api import AutoParallel
from autoparallel.auto_bucketing import configure_inductor_for_autobucketing

from torch.distributed.tensor.placement_types import Shard
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims

from torchtitan.tools.logging import logger


# TODO: Autoparallel should transparently wrap the original nn.Module,
# but I don't know how to do that.
def set_torchtitan_fields(orig, new):
    assert isinstance(new.layers, torch.nn.ModuleDict)
    for block in new.layers.values():
        block.moe_enabled = hasattr(block, "moe")


def parallelize_deepseekv3(
    model,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply Autoparallel to the model.

    NOTE: The passed-in model should preferably be on the meta device. Otherwise,
    the model must fit in GPU or CPU memory.
    """
    # TODO(whc)
    # I do this because otherwise sometimes inductor will skip re-running passes
    # like comms reordering.
    torch._inductor.config.force_disable_caches = True
    # This is necessary for working with reordering passes. Just leave it set for
    # all jobs for now.
    torch._inductor.config.allow_buffer_reuse = False

    # Allow configuring inductor comms optimizations from the torchtitan command line.
    configure_inductor_for_autobucketing(
        job_config.experimental.comms_bucket_reorder_strategy
    )

    world_mesh = parallel_dims.world_mesh

    # Update me when changing dsv3.py
    assert world_mesh.ndim == 2, "AP dsv3.py's local_map is specialized on 2 dims"
    assert world_mesh.mesh_dim_names == (
        "dp_shard_mod_ep",
        "dp_shard_in_ep",
    ), "Current setup assumes these specific meshes"

    # Provide the AP MoE modules with the mesh
    for layer in model.layers.values():
        if layer.moe_enabled:
            layer.moe.mesh = world_mesh
            layer.moe.axis_name = "dp_shard_in_ep"

    def input_fn():
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation step.
            dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
            global_batch_size = job_config.training.local_batch_size * dp_degree
        return (
            torch.randint(
                0,
                model.model_args.vocab_size,
                (global_batch_size, job_config.training.seq_len),
                device=torch.device("cuda"),
            ),
        )

    should_compile = job_config.compile.enable
    if should_compile:
        # TODO: support more options in the AP API
        assert job_config.compile.components == ["model"]
        assert job_config.compile.backend == "inductor"

    mp_policy = None
    with AutoParallel(
        model,
        input_fn,
        world_mesh,
        mp_policy=mp_policy,
        compile=should_compile,
        dynamic=True,
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)

        x_sharding = (Shard(0), Shard(0))
        loss_parallel_enabled = (
            parallel_dims.tp_enabled
            and not job_config.parallelism.disable_loss_parallel
        )
        assert not loss_parallel_enabled
Contributor: Instead of failing here, why not disable loss parallel and give a warning?

Author (Member): Warnings are easy to miss; you won't know loss parallel is disabled if you don't spot it before titan starts dumping its logs.
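For concreteness, the alternative discussed in this thread would look roughly like the sketch below (hypothetical; the PR keeps the hard assert instead):

```python
# Hypothetical alternative to the hard assert: degrade gracefully with a warning.
if loss_parallel_enabled:
    logger.warning(
        "Loss parallel is not supported by the AutoParallel DSv3 local_map path; "
        "disabling it for this run."
    )
    loss_parallel_enabled = False
```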
        autop.add_input_constraints([x_sharding])
        autop.add_output_constraints([x_sharding])
        t0 = time.time()
        sharding_placement = autop.optimize_placement()
        t1 = time.time()
        logger.info(f"AutoParallel took {t1 - t0} seconds")
        parallel_mod = autop.apply_placement(sharding_placement)

    set_torchtitan_fields(model, parallel_mod)

    if loss_parallel_enabled:

        # The current PyTorch implementation of loss parallel assumes
        # that the DTensor has a 1D device mesh. This is not true
        # in our case, but we can work around it by casting the
        # output to a DTensor on a 1D device mesh.
        # We should just use AutoParallel to do this for us, but
        # that would require putting the loss inside the model as well.
        def _return_as_dtensor_for_loss_parallel(module, args, output):
            return torch.distributed.tensor.DTensor.from_local(
                output, world_mesh["tp"], (Shard(2),)
            )

        # Not keeping a reference to the hook; we don't plan on
        # removing it at any point.
        parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)

    _preserve_moe_attributes(model, parallel_mod)

    return parallel_mod


def _preserve_moe_attributes(original_model, parallel_model):
    """
    Preserve MoE custom attributes from the original model on the parallel model.
    This is only needed for attributes that aren't used in the graph, so they aren't
    lifted as graph inputs and fetched by the pre-graph runtime wrapper.

    `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify
    this block as a MoE block. This should be safe as they are read-only.
    """

    def get_moe_modules(model):
        """Extract all MoE modules from the model."""
        moe_modules = []
        if hasattr(model, "layers"):
            if isinstance(model.layers, torch.nn.ModuleDict):
                # regular torchtitan structure
                blocks = model.layers.values()
            else:
                # autoparallel might change the structure
                blocks = (
                    model.layers.children() if hasattr(model.layers, "children") else []
                )

            for block in blocks:
                if (
                    hasattr(block, "moe_enabled")
                    and block.moe_enabled
                    and hasattr(block, "moe")
                ):
                    moe_modules.append(block.moe)
                elif hasattr(block, "moe"):  # fallback for autoparallel
                    moe_modules.append(block.moe)
        return moe_modules

    original_moe_modules = get_moe_modules(original_model)
    parallel_moe_modules = get_moe_modules(parallel_model)

    # Copy custom attributes from the original to the parallel MoE modules.
    # This is fine to do since these attributes are read-only.
    for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules):
        if hasattr(orig_moe, "moe_enabled"):
            par_moe.load_balance_coeff = orig_moe.load_balance_coeff

        # Copy load_balance_coeff
        if hasattr(orig_moe, "load_balance_coeff"):
            par_moe.load_balance_coeff = orig_moe.load_balance_coeff
This seems okay for now, but when we enable TP this should change, so just add a comment to modify this when enabling TP.

I added an assert to check for dp2ep specifically, because I don't think we're handling the mesh setup well...
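To make the mesh assumption in this exchange concrete, here is a hedged sketch of the 2D mesh the asserts in `parallelize_deepseekv3` expect, built directly with `init_device_mesh` (torchtitan's trainer constructs its own equivalent; the 1 x 2 split shown for the NGPU=2 debug command is an assumption about how dp_shard and ep fold into these two dims):

```python
# Requires torch.distributed to be initialized (e.g. run under torchrun).
from torch.distributed.device_mesh import init_device_mesh

# 1 x 2 = 2 ranks, matching the NGPU=2 debug command in the README
# (--parallelism.data_parallel_shard_degree 2 --parallelism.expert_parallel_degree 2).
world_mesh = init_device_mesh(
    "cuda",
    (1, 2),
    mesh_dim_names=("dp_shard_mod_ep", "dp_shard_in_ep"),
)
assert world_mesh.ndim == 2
assert world_mesh.mesh_dim_names == ("dp_shard_mod_ep", "dp_shard_in_ep")
```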