Skip to content

Commit 71cb39b

Browse files
committed
add manual bucketing pass
1 parent 81a36c5 commit 71cb39b

File tree

4 files changed

+50
-14
lines changed

4 files changed

+50
-14
lines changed

torchtitan/experiments/simple_fsdp/README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,14 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing
5151

5252
2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: it is not guaranteed that users will get the most optimized training performance**
5353
- "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend.
54-
55-
56-
users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:
57-
58-
```bash
59-
--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_autobucketing"
60-
```
54+
```bash
55+
--compile.backend "aot_eager" --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_autobucketing"
56+
```
57+
3. manual optimization: perform manual bucketing & reordering with user FQN inputs.
58+
- "aot_eager_manualbucketing": perform manual bucketing at aten fx-level, and perform code execution with aot_eager backend.
59+
```bash
60+
--compile.backend "aot_eager" --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_manualbucketing" --compile.manual_bucketed_modules "tok_embeddings,layers.[0-5],norm+output"
61+
```
6162

6263
### Citation
6364

torchtitan/experiments/simple_fsdp/backend.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,18 @@
77
from typing import Any, Union
88

99
import torch
10+
from torchtitan.config import JobConfig
1011

1112

12-
def get_compile_backend(backend_name: str) -> Union[str, callable]:
13+
def get_compile_backend(job_config: JobConfig) -> Union[str, callable]:
1314
# return the compile backends used in SimpleFSDP training
1415
# Step1: check if backend_name is inside available torch.compile backends
1516
# Step2: check if the backend_name has been registered as a customized backend
17+
backend_name = (
18+
getattr(job_config.compile, "model_backend_override", None)
19+
or job_config.compile.backend
20+
)
21+
1622
available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
1723
if backend_name in available_torch_backend:
1824
return backend_name
@@ -41,6 +47,36 @@ def aten_autobucketing_reordering_pass(
4147
bw_compiler=aten_autobucketing_reordering_pass,
4248
keep_inference_input_mutations=True,
4349
)
50+
elif backend_name == "aot_eager_manualbucketing":
51+
# Perform manual optimization at aten fx-level and execute code with the aot_eager backend
52+
# The manual bucketing logic lives in torch._inductor.fx_passes.overlap_manual_scheduling.
53+
bucketing_modules = job_config.compile.manual_bucketed_modules
54+
from functools import partial
55+
56+
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
57+
from torch._inductor.fx_passes.overlap_manual_scheduling import (
58+
manual_overlap_bucketing,
59+
)
60+
61+
torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
62+
torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
63+
torch._inductor.config.allow_buffer_reuse = False
64+
manual_overlap_bucketing = partial(
65+
manual_overlap_bucketing,
66+
module_bucket_plans=job_config.compile.manual_bucketed_modules,
67+
)
68+
69+
def aten_manualbucketing_reordering_pass(
70+
gm: torch.fx.GraphModule, example_inputs: Any
71+
) -> torch.fx.GraphModule:
72+
manual_overlap_bucketing(gm)
73+
return gm
74+
75+
backend = aot_autograd_backend(
76+
fw_compiler=aten_manualbucketing_reordering_pass,
77+
bw_compiler=aten_manualbucketing_reordering_pass,
78+
keep_inference_input_mutations=True,
79+
)
4480
else:
4581
raise AssertionError(f"Unsupported customized backend: {backend_name}")
4682

torchtitan/experiments/simple_fsdp/job_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
@dataclass
1111
class Compile:
1212
model_backend_override: str | None = None
13-
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
13+
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """
14+
15+
manual_bucketed_modules: list[str] = field(default_factory=list)
16+
"""Which modules should be bucketed together based on user specifications in manual optimization """
1417

1518

1619
@dataclass

torchtitan/experiments/simple_fsdp/llama3/parallelize.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,9 @@ def parallelize_llama(
125125

126126
if job_config.compile.enable and "model" in job_config.compile.components:
127127
torch._inductor.config.reorder_for_peak_memory = False
128-
backend = (
129-
getattr(job_config.compile, "model_backend_override", None)
130-
or job_config.compile.backend
131-
)
132128
model = torch.compile(
133129
model,
134-
backend=get_compile_backend(backend),
130+
backend=get_compile_backend(job_config),
135131
fullgraph=True,
136132
)
137133

0 commit comments

Comments
 (0)