Skip to content

Commit 3c46d64

Browse files
committed
add manual bucketing pass
1 parent 81a36c5 commit 3c46d64

File tree

5 files changed

+61
-28
lines changed

5 files changed

+61
-28
lines changed

torchtitan/experiments/simple_fsdp/README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,14 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing
5151

5252
2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: it is not guaranteed that users will get the most optimized training performance**
5353
- "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend.
54-
55-
56-
users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:
57-
58-
```bash
59-
--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_autobucketing"
60-
```
54+
```bash
55+
--compile.backend "aot_eager" --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_autobucketing"
56+
```
57+
3. manual optimization: perform manual bucketing & reordering with user FQN inputs.
58+
- "aot_eager_manualbucketing": perform manual bucketing at aten fx-level, and perform code execution with aot_eager backend.
59+
```bash
60+
--compile.backend "aot_eager" --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_manualbucketing" --compile.manual_bucketed_modules "tok_embeddings,layers.[0-5],norm+output"
61+
```
6162

6263
### Citation
6364

torchtitan/experiments/simple_fsdp/backend.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,18 @@
77
from typing import Any, Union
88

99
import torch
10+
from torchtitan.config import JobConfig
1011

1112

12-
def get_compile_backend(backend_name: str) -> Union[str, callable]:
13+
def get_compile_backend(job_config: JobConfig) -> Union[str, callable]:
1314
# return the compile backends used in SimpleFSDP training
1415
# Step1: check if backend_name is inside available torch.compile backends
1516
# Step2: check if the backend_name has been registered as a customized backend
17+
backend_name = (
18+
getattr(job_config.compile, "model_backend_override", None)
19+
or job_config.compile.backend
20+
)
21+
1622
available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
1723
if backend_name in available_torch_backend:
1824
return backend_name
@@ -25,8 +31,6 @@ def get_compile_backend(backend_name: str) -> Union[str, callable]:
2531
schedule_overlap_bucketing,
2632
)
2733

28-
torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
29-
torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
3034
torch._inductor.config.allow_buffer_reuse = False
3135

3236
def aten_autobucketing_reordering_pass(
@@ -41,6 +45,34 @@ def aten_autobucketing_reordering_pass(
4145
bw_compiler=aten_autobucketing_reordering_pass,
4246
keep_inference_input_mutations=True,
4347
)
48+
elif backend_name == "aot_eager_manualbucketing":
49+
# Perform manual optimization in aten fx-level and execute code in aot_eager backend
50+
# The manualbucketing logic is here:
51+
bucketing_modules = job_config.compile.manual_bucketed_modules
52+
from functools import partial
53+
54+
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
55+
from torch._inductor.fx_passes.overlap_manual_scheduling import (
56+
manual_overlap_bucketing,
57+
)
58+
59+
torch._inductor.config.allow_buffer_reuse = False
60+
manual_overlap_bucketing = partial(
61+
manual_overlap_bucketing,
62+
module_bucket_plans=job_config.compile.manual_bucketed_modules,
63+
)
64+
65+
def aten_manualbucketing_reordering_pass(
66+
gm: torch.fx.GraphModule, example_inputs: Any
67+
) -> torch.fx.GraphModule:
68+
manual_overlap_bucketing(gm)
69+
return gm
70+
71+
backend = aot_autograd_backend(
72+
fw_compiler=aten_manualbucketing_reordering_pass,
73+
bw_compiler=aten_manualbucketing_reordering_pass,
74+
keep_inference_input_mutations=True,
75+
)
4476
else:
4577
raise AssertionError(f"Unsupported customized backend: {backend_name}")
4678

torchtitan/experiments/simple_fsdp/job_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
@dataclass
1111
class Compile:
1212
model_backend_override: str | None = None
13-
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
13+
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """
14+
15+
manual_bucketed_modules: list[str] = field(default_factory=list)
16+
"""Which modules should be bucketed together based on user specifications in manual optimization """
1417

1518

1619
@dataclass

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,9 @@ def parallelize_llama(
125125

126126
if job_config.compile.enable and "model" in job_config.compile.components:
127127
torch._inductor.config.reorder_for_peak_memory = False
128-
backend = (
129-
getattr(job_config.compile, "model_backend_override", None)
130-
or job_config.compile.backend
131-
)
132128
model = torch.compile(
133129
model,
134-
backend=get_compile_backend(backend),
130+
backend=get_compile_backend(job_config),
135131
fullgraph=True,
136132
)
137133

torchtitan/experiments/simple_fsdp/tests/integration_tests.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,19 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
2929
"1D",
3030
"1d",
3131
),
32-
OverrideDefinitions(
33-
[
34-
[
35-
"--model.name simple_fsdp.llama3",
36-
"--compile.enable",
37-
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
38-
"--compile.model_backend_override aot_eager_autobucketing",
39-
],
40-
],
41-
"1D+aot_eager_autobucketing",
42-
"1d_aot_eager_autobucketing",
43-
),
32+
# TODO(ruisizhang123): add back after autobucketing pass is mature
33+
# OverrideDefinitions(
34+
# [
35+
# [
36+
# "--model.name simple_fsdp.llama3",
37+
# "--compile.enable",
38+
# "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
39+
# "--compile.model_backend_override aot_eager_autobucketing",
40+
# ],
41+
# ],
42+
# "1D+aot_eager_autobucketing",
43+
# "1d_aot_eager_autobucketing",
44+
# ),
4445
OverrideDefinitions(
4546
[
4647
[

0 commit comments

Comments
 (0)