
Commit 0577464

add joint graph runner deepseek_v3 experiment
ghstack-source-id: d5493a2 Pull Request resolved: #1906
1 parent db82f8b commit 0577464

5 files changed: +445 −1 lines changed

torchtitan/experiments/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -5,5 +5,11 @@
 # LICENSE file in the root directory of this source tree.

 _supported_experiments = frozenset(
-    ["flux", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
+    [
+        "flux",
+        "simple_fsdp.llama3",
+        "simple_fsdp.deepseek_v3",
+        "vlm",
+        "compiler_toolkit.deepseek_v3",
+    ]
 )
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
## Compiler Toolkit

Exploring toolkit-style use of the compiler stack for authoring parallel models.

Joint Graph based Training Prototype:

DeepSeek v3
- DTensor based model authoring
- Trace joint graph
- Apply optimizations to the joint/fw/bw graphs
- Run using the aot_compile_joint_with_descriptors API

Run with: NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
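
The steps listed in this README map directly onto the joint_graph_builder and CompiledModule code this commit adds in parallelize.py (the last file below). A condensed, non-authoritative sketch of that flow follows; it is not part of the commit's files, the function name compile_with_joint_graph is illustrative, and export_joint is this experiment's graph_utils helper used by the commit itself.

# Minimal sketch, assuming `model` is a DTensor-parallelized nn.Module and
# `inputs` is a tuple of DTensors (both placeholders).
import torch

from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors
from torch._guards import tracing

from torchtitan.experiments.compiler_toolkit.graph_utils import export_joint


def compile_with_joint_graph(model: torch.nn.Module, inputs: tuple):
    # 1. Trace the joint forward+backward graph once.
    joint_with_descriptors, tracing_context = export_joint(model, inputs)

    # 2. Optimization hooks: each receives a GraphModule and returns the
    #    (possibly rewritten) graph that will actually run.
    def fw_compiler(gm, example_inputs):
        return gm  # inspect/rewrite the forward graph here

    def bw_compiler(gm, example_inputs):
        return gm  # inspect/rewrite the backward graph here

    # 3. Build a callable from the joint graph under the original tracing context.
    with tracing(tracing_context):
        fn = aot_compile_joint_with_descriptors(
            joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler
        )

    # 4. The callable expects parameters, buffers, and inputs flattened together;
    #    torchtitan's runner then drives backward and the optimizer as usual.
    return lambda: fn(*model.parameters(), *model.buffers(), *inputs)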
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.distributed.pipeline_parallel import pipeline_llm

from torchtitan.experiments.simple_fsdp.deepseek_v3.model import (
    SimpleFSDPDeepSeekV3Model,
)
from torchtitan.models.deepseek_v3 import deepseekv3_args
from torchtitan.protocols.train_spec import TrainSpec

from .parallelize import parallelize_deepseekv3


def get_train_spec() -> TrainSpec:
    return TrainSpec(
        model_cls=SimpleFSDPDeepSeekV3Model,
        model_args=deepseekv3_args,
        parallelize_fn=parallelize_deepseekv3,
        pipelining_fn=pipeline_llm,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib

import torch
import torch.nn as nn

from torch._functorch.aot_autograd import (
    aot_compile_joint_with_descriptors
)
from torch._guards import tracing, TracingContext

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp

from torchtitan.experiments.simple_fsdp.deepseek_v3.model import (
    SimpleFSDPDeepSeekV3Model,
)
from torchtitan.experiments.simple_fsdp.simple_fsdp import (
    data_parallel,
    MixedPrecisionPolicy,
)

from torchtitan.models.deepseek_v3.infra.parallelize import (
    apply_ac,
    apply_moe_ep_tp,
    apply_non_moe_tp,
)
from torchtitan.tools.logging import logger

from torchtitan.experiments.compiler_toolkit.graph_utils import export_joint, print_if_rank0


# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    world_mesh = parallel_dims.world_mesh
    # TODO: TP currently cannot handle uneven seq_len because we set
    # `use_local_output=True` to use plain Tensors for legacy reasons.
    # Need to revisit this.
    assert (
        job_config.training.seq_len % parallel_dims.seq_len_divisor == 0
    ), f"""
    Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
    ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}), i.e. {parallel_dims.seq_len_divisor}.
    """
    if (
        job_config.parallelism.context_parallel_degree > 1
        and model.model_args.use_flex_attn
    ):
        raise NotImplementedError("CP support for FlexAttention is still in progress.")

    if parallel_dims.tp_enabled:
        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
        if enable_float8_tensorwise_tp:
            # TODO(jianiw): This branch needs to be tested and enabled
            raise NotImplementedError(
                "Currently, float8 tensorwise TP is not tested for deepseekv3"
            )

        use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
        apply_non_moe_tp(
            model,
            world_mesh["tp"],
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=False,
            use_flex_attn=use_flex_attn,
        )
        maybe_enable_async_tp(job_config, world_mesh["tp"])

    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
        apply_moe_ep_tp(
            model,
            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
            ep_tp_mesh=(
                world_mesh["ep", "tp"]
                if parallel_dims.tp_enabled
                and parallel_dims.ep_enabled
                and parallel_dims.etp_enabled
                else None
            ),
            etp_enabled=parallel_dims.etp_enabled,
        )
    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    mp_policy = MixedPrecisionPolicy(
        param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
        reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
    )

    # apply data parallel
    dp_mesh: DeviceMesh | None = None
    if (
        parallel_dims.fsdp_enabled
        or parallel_dims.ep_enabled
        or parallel_dims.dp_replicate_enabled
    ):
        if parallel_dims.dp_replicate_enabled:
            if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
                dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
                dp_mode = "hybrid_shard"
            else:
                dp_mesh_dim_names = ("dp_replicate",)
                dp_mode = "replicate"
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)
            dp_mode = "fully_shard"
        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
        dp_mod_ep_mesh_dim_names = []
        if parallel_dims.ep_enabled:
            if parallel_dims.dp_replicate_enabled:
                dp_mod_ep_mesh_dim_names.append("dp_replicate")
            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
            dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
        for _, transformer_block in model.layers.items():
            if transformer_block.moe_enabled and parallel_dims.ep_enabled:
                experts_shard_dim = 0
                assert dp_mod_ep_mesh is not None
                assert hasattr(transformer_block, "moe")
                if (
                    dp_mod_ep_mesh.size() * parallel_dims.ep
                    > transformer_block.moe.experts.num_experts
                ):
                    experts_shard_dim = 1
                transformer_block.moe.experts = data_parallel(
                    transformer_block.moe.experts,
                    dp_mod_ep_mesh,
                    dp_mode,
                    ac_mode=job_config.activation_checkpoint.mode,
                    mp_policy=mp_policy,
                    shard_dim=experts_shard_dim,
                )
                # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
                # transformer_block.moe.experts.set_gradient_divide_factor(
                #     parallel_dims.fsdp_gradient_divide_factor,
                # )
        model = data_parallel(
            model,
            dp_mesh,
            dp_mode,
            ac_mode=job_config.activation_checkpoint.mode,
            mp_policy=mp_policy,
        )
        logger.info(
            "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode
        )
    if job_config.compile.enable:
        # TODO: CompiledModule should take sample input as well, so that we can
        # compile ahead of time.
        model = CompiledModule(model, parallel_dims)

    return model


class CompiledModule(torch.nn.Module):
    def __init__(self, inner: torch.nn.Module, parallel_dims, **overrides):
        super().__init__()
        self.inner = inner  # register as submodule
        self.parallel_dims = parallel_dims

        self.joint_graph_module = None
        self._overrides = overrides  # for custom hooks

    def __getattr__(self, name):
        # check overrides
        if "_overrides" in self.__dict__ and name in self._overrides:
            return self._overrides[name]
        try:
            # let nn.Module handle registered stuff
            return super().__getattr__(name)
        except AttributeError:
            # fallback to inner model
            return getattr(self.inner, name)

    def __setattr__(self, name, value):
        if "_overrides" in self.__dict__ and name in self._overrides:
            self._overrides[name] = value
        else:
            super().__setattr__(name, value)

    def __delattr__(self, name):
        if "_overrides" in self.__dict__ and name in self._overrides:
            del self._overrides[name]
        else:
            super().__delattr__(name)

    def forward(self, *args, **kwargs):
        assert "forward" not in self._overrides, "forward cannot be overridden"
        dt_args = tuple(
            DTensor.from_local(arg, self.parallel_dims.world_mesh["tp"], [Replicate()])
            for arg in args
        )
        if self.joint_graph_module is None:
            self.joint_graph_module = joint_graph_builder(
                self.inner, *dt_args, **kwargs
            )

        # calling the line below returns control to torchtitan's runner
        # letting it call the backward, and optimizer.

        # TODO: add support for kwargs
        return self.joint_graph_module(args)


def joint_graph_builder(model, *inputs, **kwargs):
    assert isinstance(inputs, tuple)
    for input in inputs:
        assert isinstance(input, DTensor)

    # get joint graph
    (
        joint_with_descriptors,
        tracing_context,
    ) = export_joint(model, inputs)

    def fw_compiler(gm: torch.fx.GraphModule, example_inputs):
        print_if_rank0("fwd_gm:")
        print_if_rank0(gm.print_readable(print_output=False))

        # print_if_rank0("After compiler:")
        # print_if_rank0(gm.print_readable(print_output=False))
        return gm

    def bw_compiler(gm: torch.fx.GraphModule, example_inputs):
        print_if_rank0("bwd_gm:")
        print_if_rank0(gm.print_readable(print_output=False))

        # print_if_rank0("After compiler:")
        # print_if_rank0(gm.print_readable(print_output=False))
        return gm

    with tracing(tracing_context):
        fn = aot_compile_joint_with_descriptors(
            joint_with_descriptors, fw_compiler=fw_compiler, bw_compiler=bw_compiler
        )

    def wrapper_fn(args):
        input = [
            *model.parameters(),
            *model.buffers(),
            *args,
        ]
        return fn(*input)

    return wrapper_fn
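
For context, once --compile.enable is set the trainer interacts with the CompiledModule wrapper above like any other module. A hypothetical usage sketch, mirroring the code in this file (parallelized_model, parallel_dims, and tokens are illustrative names):

# Illustrative only: how parallelize_deepseekv3 wraps the model and how the
# training loop then calls it.
compiled = CompiledModule(parallelized_model, parallel_dims)

# First call: plain tensor args are converted to replicated DTensors on the TP
# mesh, joint_graph_builder traces and compiles the joint graph, and the result
# is cached in compiled.joint_graph_module; later calls reuse it.
output = compiled(tokens)

# Attributes not overridden or registered on the wrapper fall through to the
# inner model via CompiledModule.__getattr__.
args = compiled.model_args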
