Skip to content

Commit 0943771

Browse files
authored
Separate SAC Wrapping of MoE and Attention Modules to Enable Flex Attention Compilation (#1683)
Flex Attention requires compilation via torch.compile to achieve optimal performance, so torch.compile is always applied to Flex Attention regardless of the compile.enable flag. However, when Selective Activation Checkpointing (SAC) is enabled, torch.compile may be bypassed or invalidated under certain conditions: (1) if compile.enable is set to False, SAC ignores any torch.compile calls within the SAC region; (2) if compile.enable is True but the transformer block includes a Mixture of Experts (MoE) module, torch.compile is likewise bypassed or invalidated. To address this limitation, this PR separates the SAC wrapping of the Attention module from the SAC wrapping of the MoE and FeedForward modules. This separation ensures that Flex Attention can be compiled successfully even when SAC is enabled. The Attention module is wrapped with full AC if compile.enable is False. FIX (workaround) #1631
1 parent 7e805a9 commit 0943771

File tree

9 files changed

+278
-131
lines changed

9 files changed

+278
-131
lines changed

tests/unit_tests/test_activation_checkpoint.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
# for selective op activation checkpointing
18-
_save_list = {
18+
_op_sac_save_list = {
1919
torch.ops.aten.mm.default,
2020
torch.ops.aten._scaled_dot_product_efficient_attention.default,
2121
torch.ops.aten._scaled_dot_product_flash_attention.default,
@@ -85,7 +85,7 @@ def get_bw_flops(model_fn):
8585
ac_config_no_force,
8686
model_compile_enabled=False,
8787
use_flex_attn=False,
88-
save_list=_save_list,
88+
op_sac_save_list=_op_sac_save_list,
8989
)
9090
flops_selective_ac = get_bw_flops(model_selective_ac)
9191

@@ -103,7 +103,7 @@ def get_bw_flops(model_fn):
103103
ac_config_with_force_first,
104104
model_compile_enabled=False,
105105
use_flex_attn=False,
106-
save_list=_save_list,
106+
op_sac_save_list=_op_sac_save_list,
107107
)
108108
flops_with_force_first = get_bw_flops(model_with_force_first)
109109

@@ -120,7 +120,7 @@ def get_bw_flops(model_fn):
120120
ac_config_with_force_last,
121121
model_compile_enabled=False,
122122
use_flex_attn=False,
123-
save_list=_save_list,
123+
op_sac_save_list=_op_sac_save_list,
124124
)
125125
flops_with_force_last = get_bw_flops(model_with_force_last)
126126

@@ -135,7 +135,7 @@ def get_bw_flops(model_fn):
135135
ac_config_full_ac,
136136
model_compile_enabled=False,
137137
use_flex_attn=False,
138-
save_list=_save_list,
138+
op_sac_save_list=_op_sac_save_list,
139139
)
140140
flops_full_ac = get_bw_flops(model_with_full_ac)
141141

@@ -178,7 +178,7 @@ def get_act_mem(model_fn):
178178
ac_config_no_force,
179179
model_compile_enabled=False,
180180
use_flex_attn=False,
181-
save_list=_save_list,
181+
op_sac_save_list=_op_sac_save_list,
182182
)
183183
mem_selective_ac = get_act_mem(model_selective_ac)
184184

@@ -195,7 +195,7 @@ def get_act_mem(model_fn):
195195
ac_config_with_force_first,
196196
model_compile_enabled=False,
197197
use_flex_attn=False,
198-
save_list=_save_list,
198+
op_sac_save_list=_op_sac_save_list,
199199
)
200200
mem_with_force_first = get_act_mem(model_with_force_first)
201201

@@ -211,7 +211,7 @@ def get_act_mem(model_fn):
211211
ac_config_with_force_last,
212212
model_compile_enabled=False,
213213
use_flex_attn=False,
214-
save_list=_save_list,
214+
op_sac_save_list=_op_sac_save_list,
215215
)
216216
mem_with_force_last = get_act_mem(model_with_force_last)
217217

@@ -225,7 +225,7 @@ def get_act_mem(model_fn):
225225
ac_config_full_ac,
226226
model_compile_enabled=False,
227227
use_flex_attn=False,
228-
save_list=_save_list,
228+
op_sac_save_list=_op_sac_save_list,
229229
)
230230
mem_full_ac = get_act_mem(model_with_full_ac)
231231

@@ -252,7 +252,7 @@ def test_correctness(self):
252252
),
253253
model_compile_enabled=False,
254254
use_flex_attn=False,
255-
save_list=_save_list,
255+
op_sac_save_list=_op_sac_save_list,
256256
)
257257
model_force_first = ToyModule()
258258
model_force_first.load_state_dict(model_no_ac.state_dict())
@@ -265,7 +265,7 @@ def test_correctness(self):
265265
),
266266
model_compile_enabled=False,
267267
use_flex_attn=False,
268-
save_list=_save_list,
268+
op_sac_save_list=_op_sac_save_list,
269269
)
270270

271271
model_force_last = ToyModule()
@@ -279,7 +279,7 @@ def test_correctness(self):
279279
),
280280
model_compile_enabled=False,
281281
use_flex_attn=False,
282-
save_list=_save_list,
282+
op_sac_save_list=_op_sac_save_list,
283283
)
284284

285285
def run_fwd_bwd(model, batch):

0 commit comments

Comments
 (0)