
Commit c62d60f

add aten bucketing pass
1 parent db22479 commit c62d60f

5 files changed: +56 −8 lines


torchtitan/config/job_config.py

Lines changed: 1 addition & 1 deletion
@@ -737,7 +737,7 @@ class Experimental:
 
     autop_force_bf16: bool = False
 
-    enable_simplefsdp_passes: bool = False
+    enable_autobucketing_passes: str = ""
 
 @dataclass
 class Validation:
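For reference, a minimal sketch of how the replacement field reads in context (field name and default come from this diff; the surrounding dataclass contents and comments are assumptions):

from dataclasses import dataclass


@dataclass
class Experimental:
    # ...other experimental options elided...
    autop_force_bf16: bool = False

    # "" disables autobucketing; "inductor" selects simplefsdp's inductor-level
    # autobucketing/reordering passes; "aten" selects the aten-level bucketing
    # pass (the two values handled in train.py further down).
    enable_autobucketing_passes: str = ""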

torchtitan/experiments/auto_parallel/README.md

Lines changed: 4 additions & 2 deletions
@@ -4,8 +4,10 @@ requires installing git@github.com:pytorch-labs/autoparallel.git
 
 `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`
 
-Use simplefsdp's autobucketing pass:
+Use autobucketing pass:
 
-`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable`
+`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_autobucketing_passes "aten" --compile.enable`
+
+Set `experimental.enable_autobucketing_passes` to
 
 (or llama3-8b.toml)
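The value passed to --experimental.enable_autobucketing_passes selects which pass family runs; a minimal sketch of that mapping, with a hypothetical helper name used purely for illustration (the actual branching is in the train.py changes below):

def describe_autobucketing(mode: str) -> str:
    """Map experimental.enable_autobucketing_passes to what actually runs."""
    if mode == "inductor":
        return "simplefsdp's inductor-level autobucketing and comm-reordering passes"
    if mode == "aten":
        return "aten-level autobucketing registered as a post-grad custom pass"
    return "no autobucketing; the bucket_*_fx settings apply instead"


print(describe_autobucketing("aten"))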

torchtitan/experiments/simple_fsdp/parallelize.py

Lines changed: 39 additions & 2 deletions
@@ -93,8 +93,45 @@ def parallelize_llama(
         )
         logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)
 
-    if job_config.compile.enable and "model" in job_config.compile.components:
-        torch._inductor.config.reorder_for_peak_memory = False
+    if job_config.compile.enable:
+        from functools import partial
+        bucket_level = ""
+        torch._inductor.config.run_with_post_grad_graph = False
+        if bucket_level == "inductor":
+            # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
+            from autoparallel.auto_bucketing import (
+                simple_fsdp_autobucketing_reordering_pass,
+                simplefsdp_autobucketing_config,
+            )
+
+            torch._inductor.config.allow_buffer_reuse = False
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = True
+            simplefsdp_autobucketing_config.save_estimation_path = (
+                "/tmp/torchtitan_simplefsdp_comm_estimation.pkl"
+            )
+            simplefsdp_autobucketing_config.calibrate_number = 20
+            simple_fsdp_autobucketing_reordering_pass = partial(
+                simple_fsdp_autobucketing_reordering_pass,
+                configs=simplefsdp_autobucketing_config,
+            )
+            torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+                simple_fsdp_autobucketing_reordering_pass
+            ]
+
+            # Don't use both sets of passes at the same time!
+            torch._inductor.config.bucket_all_gathers_fx = "none"
+            torch._inductor.config.bucket_reduce_scatters_fx = "none"
+        elif bucket_level == "aten":
+            from autoparallel.auto_bucketing import aten_autobucketing_reordering_pass, aten_autobucketing_config
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = False
+            aten_autobucketing_reordering_pass = partial(
+                aten_autobucketing_reordering_pass,
+                configs=aten_autobucketing_config,
+            )
+            torch._inductor.config.post_grad_custom_post_pass = aten_autobucketing_reordering_pass
+
         model = torch.compile(model, fullgraph=True)
 
     return model
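As written, bucket_level is hard-coded to "", so neither the "inductor" nor the "aten" branch fires until that value is edited. A hedged sketch of driving the selector from the new experimental flag instead; this helper is hypothetical and not part of the commit:

def select_bucket_level(job_config) -> str:
    # Hypothetical: reuse the new experimental knob rather than the hard-coded
    # empty string; valid values are "", "inductor", and "aten".
    return job_config.experimental.enable_autobucketing_passes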

torchtitan/models/llama3/train_configs/llama3_8b.toml

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ local_batch_size = 1
 seq_len = 8192
 max_norm = 1.0 # grad norm clipping
 steps = 1000
-dataset = "c4"
+dataset = "c4_test"
 
 [parallelism]
 data_parallel_replicate_degree = 1

torchtitan/train.py

Lines changed: 11 additions & 2 deletions
@@ -128,9 +128,8 @@ def __init__(self, job_config: JobConfig):
         torch._inductor.config.force_disable_caches = True
         # this is necessary for working with reordering passes. Just leave it set for all the jobs for now.
         torch._inductor.config.allow_buffer_reuse = False
-
         # allow configuring inductor comms optimizations from torchtitan commandline
-        if job_config.experimental.enable_simplefsdp_passes:
+        if job_config.experimental.enable_autobucketing_passes == "inductor":
             # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
             from autoparallel.auto_bucketing import (
                 simple_fsdp_autobucketing_reordering_pass,
@@ -143,6 +142,7 @@ def __init__(self, job_config: JobConfig):
             simplefsdp_autobucketing_config.save_estimation_path = (
                 "/tmp/torchtitan_simplefsdp_comm_estimation.pkl"
             )
+            simplefsdp_autobucketing_config.calibrate_number = 20
             simple_fsdp_autobucketing_reordering_pass = partial(
                 simple_fsdp_autobucketing_reordering_pass,
                 configs=simplefsdp_autobucketing_config,
@@ -154,6 +154,15 @@ def __init__(self, job_config: JobConfig):
             # Don't use both sets of passes at the same time!
             torch._inductor.config.bucket_all_gathers_fx = "none"
             torch._inductor.config.bucket_reduce_scatters_fx = "none"
+        elif job_config.experimental.enable_autobucketing_passes == "aten":
+            from autoparallel.auto_bucketing import aten_autobucketing_reordering_pass, aten_autobucketing_config
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = False
+            aten_autobucketing_reordering_pass = partial(
+                aten_autobucketing_reordering_pass,
+                configs=aten_autobucketing_config,
+            )
+            torch._inductor.config.post_grad_custom_post_pass = aten_autobucketing_reordering_pass
         else:
             torch._inductor.config.bucket_all_gathers_fx = (
                 job_config.experimental.bucket_all_gathers_fx
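For context on the "aten" branch: torch._inductor.config.post_grad_custom_post_pass holds a callable that inductor invokes on the post-grad FX graph. A minimal, self-contained sketch of registering a pass the same way; the body below is a placeholder, not autoparallel's bucketing logic:

import torch
import torch._inductor.config as inductor_config


def placeholder_post_grad_pass(graph: torch.fx.Graph) -> None:
    # A real pass would rewrite the graph here, e.g. grouping collective ops
    # into buckets; this commit plugs in aten_autobucketing_reordering_pass.
    for node in graph.nodes:
        _ = node.target  # inspect only


inductor_config.post_grad_custom_post_pass = placeholder_post_grad_pass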
