
Commit 6e2a44f

add aten bucketing pass
1 parent db22479 commit 6e2a44f

2 files changed: +40 -3 lines changed

torchtitan/experiments/simple_fsdp/parallelize.py

Lines changed: 39 additions & 2 deletions
@@ -93,8 +93,45 @@ def parallelize_llama(
         )
         logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)
 
-    if job_config.compile.enable and "model" in job_config.compile.components:
-        torch._inductor.config.reorder_for_peak_memory = False
+    if job_config.compile.enable:
+        from functools import partial
+        bucket_level = ""
+        torch._inductor.config.run_with_post_grad_graph = True
+        if bucket_level == "inductor":
+            # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
+            from autoparallel.auto_bucketing import (
+                simple_fsdp_autobucketing_reordering_pass,
+                simplefsdp_autobucketing_config,
+            )
+
+            torch._inductor.config.allow_buffer_reuse = False
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = True
+            simplefsdp_autobucketing_config.save_estimation_path = (
+                "/tmp/torchtitan_simplefsdp_comm_estimation.pkl"
+            )
+            simplefsdp_autobucketing_config.calibrate_number = 20
+            simple_fsdp_autobucketing_reordering_pass = partial(
+                simple_fsdp_autobucketing_reordering_pass,
+                configs=simplefsdp_autobucketing_config,
+            )
+            torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+                simple_fsdp_autobucketing_reordering_pass
+            ]
+
+            # Don't use both sets of passes at the same time!
+            torch._inductor.config.bucket_all_gathers_fx = "none"
+            torch._inductor.config.bucket_reduce_scatters_fx = "none"
+        elif bucket_level == "aten":
+            from autoparallel.auto_bucketing import aten_autobucketing_reordering_pass, aten_autobucketing_config
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = False
+            aten_autobucketing_reordering_pass = partial(
+                aten_autobucketing_reordering_pass,
+                configs=aten_autobucketing_config,
+            )
+            torch._inductor.config.post_grad_custom_post_pass = aten_autobucketing_reordering_pass
+
         model = torch.compile(model, fullgraph=True)
 
     return model
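
For context, the two branches hang off different Inductor hooks: the "inductor" branch registers a scheduler-level reordering pass through reorder_for_compute_comm_overlap_passes, while the new "aten" branch registers a post-grad FX graph pass through post_grad_custom_post_pass. Below is a minimal sketch of the callable shapes these hooks expect; the function names and bodies (my_post_grad_bucketing_pass, my_reordering_pass) are hypothetical placeholders, not the autoparallel implementations.

import torch
from torch import fx

def my_post_grad_bucketing_pass(graph: fx.Graph) -> None:
    # post_grad_custom_post_pass receives the post-grad ATen FX graph and is
    # expected to mutate it in place (e.g. bucket small all-gathers into
    # larger ones). Placeholder body: it only counts collective nodes and
    # leaves the graph unchanged.
    num_all_gathers = sum(
        1 for node in graph.nodes if "all_gather" in str(node.target)
    )
    print(f"post-grad graph holds {num_all_gathers} all_gather nodes")

def my_reordering_pass(snodes):
    # Entries of reorder_for_compute_comm_overlap_passes take the list of
    # Inductor scheduler nodes and return a (possibly reordered) list so
    # compute can overlap with communication. Identity here.
    return snodes

# Hypothetical wiring mirroring the "aten" branch above; the "inductor"
# branch would instead set
# torch._inductor.config.reorder_for_compute_comm_overlap_passes = [my_reordering_pass]
torch._inductor.config.post_grad_custom_post_pass = my_post_grad_bucketing_pass

Note that as committed, bucket_level = "", so neither branch runs; flipping it to "inductor" or "aten" opts into the corresponding bucketing pass before torch.compile is applied.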

torchtitan/models/llama3/train_configs/llama3_8b.toml

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ local_batch_size = 1
 seq_len = 8192
 max_norm = 1.0  # grad norm clipping
 steps = 1000
-dataset = "c4"
+dataset = "c4_test"
 
 [parallelism]
 data_parallel_replicate_degree = 1
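
Presumably the switch from "c4" to "c4_test" points the run at torchtitan's small bundled test split so the new bucketing path can be exercised quickly, rather than streaming the full C4 dataset.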
