
Commit d0e2545

add auto_eager_graph_pass (#1813)
This PR adds the autobucketing pass at aten level to SimpleFSDP. It runs autobucketing + the aot_eager backend without Inductor. The aten fx autobucketing pass can be found in this PR: pytorch/pytorch#163960. Key updates are: 1. Support a customized `aot_eager_autobucketing` backend to perform the autobucketing optimization. 2. In SimpleFSDP, the model backend can be replaced by a user's customized passes using `compile.model_backend_override`.
1 parent 7c10480 commit d0e2545

5 files changed: 126 additions, 2 deletions

torchtitan/experiments/simple_fsdp/README.md

Lines changed: 18 additions & 1 deletion
@@ -10,7 +10,7 @@ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu
 
 This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations.
 
-### Run SimpleFSDP Training on Llama 3
+### Run SimpleFSDP Training on Llama3 & DeepSeek_v3
 
 #### Training Llama3 models
 
@@ -42,6 +42,23 @@ Some of the features require the updates from PyTorch, with which we are working
 |Expert Parallelism + Activation Checkpointing| 🚧 |
 |Expert Parallelism + Pipeline Parallelism| 🚧 |
 
+
+### Compiler Optimizations
+
+SimpleFSDP relies on a compiler backend to perform optimizations (i.e., bucketing & reordering) for good training performance. Currently, the following optimization passes are supported:
+
+1. No optimization: the default torch.compile backends (e.g., "inductor", "aot_eager", "eager").
+
+2. Auto optimization: perform auto-bucketing & reordering without user input. **Note: it is not guaranteed that users will get the most optimized training performance.**
+   - "aot_eager_autobucketing": performs autobucketing at the aten fx level and executes code with the aot_eager backend.
+
+
+Users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:
+
+```bash
+--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_autobucketing"
+```
+
 ### Citation
 
 If you find SimpleFSDP useful, please kindly consider citing the following paper:
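
For reference, a complete launch command combining these flags with torchtitan's usual `run_train.sh` entry point could look like the sketch below; the `NGPU` value and the debug config path are illustrative assumptions, not part of this commit.

```bash
# Illustrative only: the launcher and config path follow torchtitan conventions
# and are not part of this diff; adjust NGPU / CONFIG_FILE to your setup.
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --model.name simple_fsdp.llama3 \
  --compile.enable \
  --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config \
  --compile.model_backend_override "aot_eager_autobucketing"
```
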
torchtitan/experiments/simple_fsdp/backend.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Union
+
+import torch
+
+
+def get_compile_backend(backend_name: str) -> Union[str, callable]:
+    # return the compile backends used in SimpleFSDP training
+    # Step1: check if backend_name is inside available torch.compile backends
+    # Step2: check if the backend_name has been registered as a customized backend
+    available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
+    if backend_name in available_torch_backend:
+        return backend_name
+
+    if backend_name == "aot_eager_autobucketing":
+        # Perform auto optimization in aten fx-level and execute code in aot_eager backend
+        # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
+        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+        from torch._inductor.fx_passes.overlap_scheduling import (
+            schedule_overlap_bucketing,
+        )
+
+        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
+        torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
+        torch._inductor.config.allow_buffer_reuse = False
+
+        def aten_autobucketing_reordering_pass(
+            gm: torch.fx.GraphModule, example_inputs: Any
+        ) -> torch.fx.GraphModule:
+            schedule_overlap_bucketing(gm)
+            gm.recompile()
+            return gm
+
+        backend = aot_autograd_backend(
+            fw_compiler=aten_autobucketing_reordering_pass,
+            bw_compiler=aten_autobucketing_reordering_pass,
+            keep_inference_input_mutations=True,
+        )
+    else:
+        raise AssertionError(f"Unsupported customized backend: {backend_name}")
+
+    return backend
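
A minimal usage sketch of the helper above, assuming a recent PyTorch nightly that ships the autobucketing pass from pytorch/pytorch#163960. The module path is inferred from the `from ..backend import get_compile_backend` import in the parallelize.py change below, and the toy model is purely illustrative.

```python
# Sketch only: requires a PyTorch nightly containing schedule_overlap_bucketing.
import torch
import torch.nn as nn

from torchtitan.experiments.simple_fsdp.backend import get_compile_backend

# Built-in torch.compile backend names are returned unchanged.
assert get_compile_backend("aot_eager") == "aot_eager"

# The customized name resolves to an aot_autograd-wrapped callable that runs
# the bucketing/reordering pass on the forward and backward graphs.
backend = get_compile_backend("aot_eager_autobucketing")

model = nn.Linear(16, 16)  # toy stand-in for the SimpleFSDP-wrapped model
compiled = torch.compile(model, backend=backend, fullgraph=True)
loss = compiled(torch.randn(4, 16)).sum()
loss.backward()
```
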
torchtitan/experiments/simple_fsdp/job_config.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class Compile:
+    model_backend_override: str | None = None
+    """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
+
+
+@dataclass
+class JobConfig:
+    compile: Compile = field(default_factory=Compile)
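
As a quick sanity check, the extension dataclass can be exercised on its own; how torchtitan merges it into the main job config via `--job.custom_config_module` is handled by the framework and is not reproduced here.

```python
from torchtitan.experiments.simple_fsdp.job_config import Compile, JobConfig

# Default: no override, so the parallelize code falls back to compile.backend.
cfg = JobConfig()
assert cfg.compile.model_backend_override is None

# With the override set, SimpleFSDP resolves it through get_compile_backend().
cfg = JobConfig(compile=Compile(model_backend_override="aot_eager_autobucketing"))
print(cfg.compile.model_backend_override)  # aot_eager_autobucketing
```
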

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 10 additions & 1 deletion
@@ -14,6 +14,8 @@
 from torchtitan.models.llama3.infra.parallelize import apply_tp
 from torchtitan.tools.logging import logger
 
+from ..backend import get_compile_backend
+
 from ..simple_fsdp import data_parallel, MixedPrecisionPolicy
 
 
@@ -123,6 +125,13 @@ def parallelize_llama(
 
     if job_config.compile.enable and "model" in job_config.compile.components:
         torch._inductor.config.reorder_for_peak_memory = False
-        model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
+        backend = (
+            job_config.compile.model_backend_override or job_config.compile.backend
+        )
+        model = torch.compile(
+            model,
+            backend=get_compile_backend(backend),
+            fullgraph=True,
+        )
 
     return model
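
The override only takes effect when it is set; otherwise the regular `compile.backend` is used. A tiny sketch of that fallback, with hypothetical values:

```python
# Hypothetical stand-ins for the job_config.compile fields used above.
compile_backend = "inductor"       # compile.backend
model_backend_override = None      # compile.model_backend_override (unset)

assert (model_backend_override or compile_backend) == "inductor"

model_backend_override = "aot_eager_autobucketing"  # override set via CLI/TOML
assert (model_backend_override or compile_backend) == "aot_eager_autobucketing"
```
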

torchtitan/experiments/simple_fsdp/tests/integration_tests.py

Lines changed: 33 additions & 0 deletions
@@ -23,18 +23,32 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                 [
                     "--model.name simple_fsdp.llama3",
                     "--compile.enable",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "1D",
             "1d",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--model.name simple_fsdp.llama3",
+                    "--compile.enable",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
+                    "--compile.model_backend_override aot_eager_autobucketing",
+                ],
+            ],
+            "1D+aot_eager_autobucketing",
+            "1d_aot_eager_autobucketing",
+        ),
         OverrideDefinitions(
             [
                 [
                     "--model.name simple_fsdp.llama3",
                     "--compile.enable",
                     "--activation_checkpoint.mode selective",
                     "--activation_checkpoint.selective_ac_option op",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "1D with selective op AC",
@@ -46,6 +60,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--model.name simple_fsdp.llama3",
                     "--compile.enable",
                     "--activation_checkpoint.mode full",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "1D with full AC",
@@ -57,6 +72,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--model.name simple_fsdp.llama3",
                     "--compile.enable",
                     "--parallelism.tensor_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "2D",
@@ -70,6 +86,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--compile.enable",
                     "--parallelism.tensor_parallel_degree 2",
                     "--parallelism.enable_async_tensor_parallel",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "2D async TP",
@@ -82,12 +99,14 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--model.name simple_fsdp.llama3",
                     "--compile.enable",
                     "--checkpoint.enable",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
                 [
                     "--model.name simple_fsdp.llama3",
                     "--compile.enable",
                     "--checkpoint.enable",
                     "--training.steps 20",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "Checkpoint Integration Test - Save Load Full Checkpoint",
@@ -102,6 +121,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--parallelism.pipeline_parallel_degree 2",
                     "--parallelism.data_parallel_shard_degree 2",
                     "--parallelism.tensor_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
                 [
                     "--model.name simple_fsdp.llama3",
@@ -111,6 +131,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--parallelism.pipeline_parallel_degree 2",
                     "--parallelism.data_parallel_shard_degree 2",
                     "--parallelism.tensor_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "PP+DP+TP 3D test with save/load resume ckpt",
@@ -124,6 +145,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--compile.enable",
                     "--parallelism.data_parallel_shard_degree 1",
                     "--parallelism.data_parallel_replicate_degree 4",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ]
             ],
             "DDP",
@@ -137,6 +159,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--compile.enable",
                     "--parallelism.data_parallel_shard_degree 2",
                     "--parallelism.data_parallel_replicate_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ]
             ],
             "HSDP",
@@ -151,6 +174,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--parallelism.data_parallel_shard_degree 2",
                     "--parallelism.data_parallel_replicate_degree 2",
                     "--parallelism.tensor_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ]
             ],
             "HSDP+TP",
@@ -164,6 +188,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--compile.enable",
                     "--parallelism.data_parallel_replicate_degree 2",
                     "--parallelism.tensor_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ]
             ],
             "DDP+TP",
@@ -178,6 +203,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--parallelism.data_parallel_shard_degree 2",
                     "--parallelism.data_parallel_replicate_degree 2",
                     "--parallelism.context_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ]
             ],
             "HSDP+CP (with dp_shard)",
@@ -192,6 +218,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--parallelism.data_parallel_shard_degree 2",
                     "--parallelism.tensor_parallel_degree 2",
                     "--parallelism.context_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ]
             ],
             "FSDP+TP+CP",
@@ -205,6 +232,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--compile.enable",
                     "--checkpoint.enable",
                     "--training.steps 10",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
                 # Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be
                 # excluded during loading to avoid errors caused by mismatched dp_degree.
@@ -215,6 +243,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
                     "--parallelism.tensor_parallel_degree 2",
                     "--training.steps 20",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
                 # load at [tp:4].
                 [
@@ -224,6 +253,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
                     "--parallelism.tensor_parallel_degree 4",
                     "--training.steps 30",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "Optional checkpoint",
@@ -236,6 +266,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--model.name simple_fsdp.deepseek_v3",
                     "--parallelism.data_parallel_shard_degree 4",
                     "--parallelism.expert_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "FSDP+EP",
@@ -250,6 +281,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--parallelism.tensor_parallel_degree 2",
                     "--parallelism.expert_parallel_degree 4",
                     "--parallelism.expert_tensor_parallel_degree 1",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "FSDP+TP+EP",
@@ -264,6 +296,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
                     "--parallelism.tensor_parallel_degree 2",
                     "--parallelism.expert_parallel_degree 2",
                     "--parallelism.expert_tensor_parallel_degree 2",
+                    "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
                 ],
             ],
             "FSDP+TP+EP+ETP",
