Commit c30bf6a

[Autoparallel] Add local_map variant of DSv3 and 2D mesh AP
stack-info: PR: #2129, branch: xmfan/stack/7
1 parent 1494ccc commit c30bf6a

7 files changed: +304 -2 lines


torchtitan/experiments/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,5 +15,6 @@
         "transformers_modeling_backend",
         "autoparallel.llama3",
         "autoparallel.deepseek_v3",
+        "autoparallel.local_map_deepseek_v3",
     ]
 )

torchtitan/experiments/autoparallel/README.md

Lines changed: 6 additions & 0 deletions
@@ -17,3 +17,9 @@ Requires installing [git@github.com:meta-pytorch/autoparallel.git](https://githu
 **DeepSeekv3**
 
 `CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config`
+
+**DeepSeekv3 local_map**
+
+This is a variant of torchtitan's DSv3 that uses local_map for the expert-parallel region. It only supports a 2D mesh right now. NOTE: the provided mesh is just to reuse torchtitan's trainer mesh setup code; Autoparallel is not bound to use dp2ep.
+
+`NGPU=2 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml tlp ./run_train.sh --model.name autoparallel.local_map_deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config --parallelism.data_parallel_shard_degree 2 --parallelism.expert_parallel_degree 2`
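
For readers unfamiliar with `local_map`, the sketch below shows the general pattern this variant relies on: the expert-parallel body is written against plain local tensors, and `torch.distributed.tensor.experimental.local_map` unwraps and re-wraps DTensors on a 2D mesh around it. This is a minimal illustration only, not the code from this commit; `expert_ffn`, the mesh dimension names, and the placements are assumptions chosen for the example.

```python
# Minimal sketch (not the Autoparallel implementation): wrap a per-rank expert
# computation with local_map on a 2D device mesh. `expert_ffn`, the mesh dim
# names, and the placements below are illustrative assumptions.
# Launch with: torchrun --nproc-per-node 2 local_map_sketch.py
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, distribute_tensor
from torch.distributed.tensor.experimental import local_map
from torch.distributed.tensor.placement_types import Replicate, Shard


def expert_ffn(tokens: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Plain per-rank math: local_map hands this function local shards only.
    return torch.relu(tokens @ w)


def main() -> None:
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # 2D mesh (2 x 1 on two GPUs); names are placeholders, not torchtitan's.
    mesh = init_device_mesh("cuda", (2, 1), mesh_dim_names=("dp_shard", "ep"))

    # Tokens sharded along the batch dim of the first mesh dim; weight replicated.
    tokens = distribute_tensor(torch.randn(8, 16, device="cuda"), mesh, (Shard(0), Replicate()))
    weight = distribute_tensor(torch.randn(16, 16, device="cuda"), mesh, (Replicate(), Replicate()))

    mapped_ffn = local_map(
        expert_ffn,
        out_placements=(Shard(0), Replicate()),  # how to re-wrap the local output
        in_placements=((Shard(0), Replicate()), (Replicate(), Replicate())),
        device_mesh=mesh,
    )

    out = mapped_ffn(tokens, weight)  # DTensor in, DTensor out; body saw local tensors
    assert isinstance(out, DTensor) and out.shape == (8, 16)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

The wrapped call receives DTensors, but `expert_ffn` itself only ever sees each rank's local shard, which is the property the local_map DSv3 variant exploits for its expert-parallel region.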

torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py

Lines changed: 0 additions & 2 deletions
@@ -257,8 +257,6 @@ def set_torchtitan_fields(orig, new):
         block.moe_enabled = hasattr(block, "moe")
 
 
-# Run workflow with:
-# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_autoparallel
 def parallelize_deepseekv3(
     model,
     parallel_dims: ParallelDims,
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import copy

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

from torchtitan.models.deepseek_v3 import deepseekv3_args
from torchtitan.models.deepseek_v3.model.state_dict_adapter import (
    DeepSeekV3StateDictAdapter,
)
from torchtitan.protocols.train_spec import TrainSpec

from .args import DeepSeekV3ModelArgs, get_sample_config

from .model import DeepSeekV3Model
from .parallelize_deepseekv3 import parallelize_deepseekv3


def get_model_args() -> DeepSeekV3ModelArgs:
    model_args = copy.deepcopy(deepseekv3_args)
    # TODO: Align configs between AP and Titan
    for config in model_args.keys():
        # Just override the configs
        override = get_sample_config()
        override.update_from_config = model_args[config].update_from_config
        override.get_nparams_and_flops = model_args[config].get_nparams_and_flops
        model_args[config] = override

    return model_args


def get_train_spec() -> TrainSpec:
    model_args = get_model_args()

    return TrainSpec(
        model_cls=DeepSeekV3Model,
        model_args=model_args,
        parallelize_fn=parallelize_deepseekv3,
        pipelining_fn=pipeline_llm,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_text_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        state_dict_adapter=DeepSeekV3StateDictAdapter,
    )
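
As a quick check of the spec above, the hedged snippet below imports it and prints the overridden configs. The import path is an assumption inferred from `--model.name autoparallel.local_map_deepseek_v3` (the file's location isn't shown in this view), and `autoparallel` must be installed for the import to resolve.

```python
# Hedged usage sketch: the package path below is inferred from the
# `--model.name autoparallel.local_map_deepseek_v3` flag and is not shown
# explicitly in this diff; `autoparallel` must be installed.
from torchtitan.experiments.autoparallel.local_map_deepseek_v3 import (
    get_model_args,
    get_train_spec,
)

spec = get_train_spec()
print(spec.model_cls.__name__)  # the autoparallel-backed DeepSeekV3Model

for name, args in get_model_args().items():
    # Each entry now carries the small AP sample config (dim=256, n_layers=4, ...)
    # while keeping torchtitan's update_from_config / get_nparams_and_flops hooks.
    print(name, args.dim, args.n_layers)
```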
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from dataclasses import dataclass
from autoparallel._testing.models.dsv3 import DeepSeekV3ModelArgs as _DeepSeekV3ModelArgs, MoEArgs as _MoEArgs
from torchtitan.protocols.model import BaseModelArgs


# Need to share same base class with torchtitan models
@dataclass
class DeepSeekV3ModelArgs(_DeepSeekV3ModelArgs, BaseModelArgs):
    pass


def get_sample_config() -> DeepSeekV3ModelArgs:
    return DeepSeekV3ModelArgs(
        vocab_size=2048,
        max_seq_len=2048,
        dim=256,
        inter_dim=1024,
        moe_inter_dim=256,
        n_layers=4,
        n_dense_layers=0,
        n_heads=16,
        moe_args=_MoEArgs(
            num_experts=4,
            num_shared_experts=2,
            top_k=2,
            score_func="softmax",
            route_norm=False,
            score_before_experts=False,
            mesh=None,
        ),
        q_lora_rank=0,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        mscale=0.70,
    )
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from autoparallel._testing.models.dsv3 import DeepSeekV3Model as _DeepSeekV3Model
from torchtitan.protocols.train_spec import ModelProtocol
from .args import DeepSeekV3ModelArgs


# Need to share same base class with torchtitan models
class DeepSeekV3Model(_DeepSeekV3Model, ModelProtocol):
    def __init__(self, model_args: DeepSeekV3ModelArgs):
        super().__init__(model_args)
Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch
from autoparallel.api import AutoParallel
from autoparallel.auto_bucketing import configure_inductor_for_autobucketing

from torch.distributed.tensor.placement_types import Shard
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims

from torchtitan.tools.logging import logger


# TODO: Autoparallel should transparently wrap the original nn.Module
# but I don't know how to do that.
def set_torchtitan_fields(orig, new):
    assert isinstance(new.layers, torch.nn.ModuleDict)
    for block in new.layers.values():
        block.moe_enabled = hasattr(block, "moe")


def parallelize_deepseekv3(
    model,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply Autoparallel to the model

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    # TODO(whc)
    # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering
    torch._inductor.config.force_disable_caches = True
    # this is necessary for working with reordering passes. Just leave it set for all the jobs for now.
    torch._inductor.config.allow_buffer_reuse = False

    # allow configuring inductor comms optimizations from torchtitan commandline
    configure_inductor_for_autobucketing(
        job_config.experimental.comms_bucket_reorder_strategy
    )

    world_mesh = parallel_dims.world_mesh

    # Update me when changing dsv3.py
    assert world_mesh.ndim == 2, "AP dsv3.py's local_map is specialized on 2 dims"

    # Provide AP MoE with mesh
    for layer in model.layers.values():
        if layer.moe_enabled:
            layer.moe.mesh = world_mesh
            layer.moe.axis_name = "dp_shard_in_ep"

    def input_fn():
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
            global_batch_size = job_config.training.local_batch_size * dp_degree
        return (
            torch.randint(
                0,
                model.model_args.vocab_size,
                (global_batch_size, job_config.training.seq_len),
                device=torch.device("cuda"),
            ),
        )

    should_compile = job_config.compile.enable
    if should_compile:
        # TODO: support more options in AP API
        assert job_config.compile.components == ["model"]
        assert job_config.compile.backend == "inductor"

    mp_policy = None
    with AutoParallel(
        model,
        input_fn,
        world_mesh,
        mp_policy=mp_policy,
        compile=should_compile,
        dynamic=True,
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)

        x_sharding = (Shard(0), Shard(0))
        loss_parallel_enabled = (
            parallel_dims.tp_enabled
            and not job_config.parallelism.disable_loss_parallel
        )
        assert not loss_parallel_enabled
        autop.add_input_constraints([x_sharding])
        autop.add_output_constraints([x_sharding])
        t0 = time.time()
        sharding_placement = autop.optimize_placement()
        t1 = time.time()
        logger.info(f"AutoParallel took {t1 - t0} seconds")
        parallel_mod = autop.apply_placement(sharding_placement)

    set_torchtitan_fields(model, parallel_mod)

    if loss_parallel_enabled:

        # current PyTorch's implementation of loss parallel assumes
        # that the DTensor has a 1d device mesh. This is not true
        # in our case, but we can work around it by adding
        # casting the output to a DTensor on a 1d device mesh.
        # We should just use AutoParallel to do this for us, but
        # it would require putting the loss inside the model as well
        def _return_as_dtensor_for_loss_parallel(module, args, output):
            return torch.distributed.tensor.DTensor.from_local(
                output, world_mesh["tp"], (Shard(2),)
            )

        # not keeping a reference to the hook, don't plan on
        # removing it at any point
        parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)

    _preserve_moe_attributes(model, parallel_mod)

    return parallel_mod


def _preserve_moe_attributes(original_model, parallel_model):
    """
    Preserve MoE custom attributes from the original model to the parallel model.
    This is only needed for attributes that aren't used in the graph, so they aren't
    lifted as graph inputs and fetched by the pre-graph runtime wrapper.

    `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify
    this block as a moe block. This should be safe as they are read-only.
    """

    def get_moe_modules(model):
        """Extract all MoE modules from the model."""
        moe_modules = []
        if hasattr(model, "layers"):
            if isinstance(model.layers, torch.nn.ModuleDict):
                # regular torchtitan structure
                blocks = model.layers.values()
            else:
                # autoparallel might change structure
                blocks = (
                    model.layers.children() if hasattr(model.layers, "children") else []
                )

            for block in blocks:
                if (
                    hasattr(block, "moe_enabled")
                    and block.moe_enabled
                    and hasattr(block, "moe")
                ):
                    moe_modules.append(block.moe)
                elif hasattr(block, "moe"):  # fallback for autoparallel
                    moe_modules.append(block.moe)
        return moe_modules

    original_moe_modules = get_moe_modules(original_model)
    parallel_moe_modules = get_moe_modules(parallel_model)

    # Copy custom attributes from original to parallel MoE modules
    # This is fine to do since these attributes are read only
    for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules):
        if hasattr(orig_moe, "moe_enabled"):
            par_moe.load_balance_coeff = orig_moe.load_balance_coeff

        # Copy load_balance_coeff
        if hasattr(orig_moe, "load_balance_coeff"):
            par_moe.load_balance_coeff = orig_moe.load_balance_coeff
