@@ -93,8 +93,45 @@ def parallelize_llama(
     )
     logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)
 
-    if job_config.compile.enable and "model" in job_config.compile.components:
-        torch._inductor.config.reorder_for_peak_memory = False
+    if job_config.compile.enable:
+        from functools import partial
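+        # bucket_level selects the autobucketing variant: "inductor" (simplefsdp's inductor-level passes), "aten" (aten-level post-grad pass), or "" to disable both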
+        bucket_level = ""
+        torch._inductor.config.run_with_post_grad_graph = True
+        if bucket_level == "inductor":
+            # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282)
+            from autoparallel.auto_bucketing import (
+                simple_fsdp_autobucketing_reordering_pass,
+                simplefsdp_autobucketing_config,
+            )
+
+            torch._inductor.config.allow_buffer_reuse = False
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = True
+            simplefsdp_autobucketing_config.save_estimation_path = (
+                "/tmp/torchtitan_simplefsdp_comm_estimation.pkl"
+            )
+            simplefsdp_autobucketing_config.calibrate_number = 20
+            simple_fsdp_autobucketing_reordering_pass = partial(
+                simple_fsdp_autobucketing_reordering_pass,
+                configs=simplefsdp_autobucketing_config,
+            )
+            torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+                simple_fsdp_autobucketing_reordering_pass
+            ]
+
+            # Don't use both sets of passes at the same time!
+            torch._inductor.config.bucket_all_gathers_fx = "none"
+            torch._inductor.config.bucket_reduce_scatters_fx = "none"
+        elif bucket_level == "aten":
+            from autoparallel.auto_bucketing import aten_autobucketing_reordering_pass, aten_autobucketing_config
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = False
+            aten_autobucketing_reordering_pass = partial(
+                aten_autobucketing_reordering_pass,
+                configs=aten_autobucketing_config,
+            )
+            torch._inductor.config.post_grad_custom_post_pass = aten_autobucketing_reordering_pass
+
         model = torch.compile(model, fullgraph=True)
 
     return model