Commit bdefa26

rohan-varma authored and pytorchmergebot committed
[RFC] Separate CPU offload activation to its own wrapper (pytorch#85459)
Passing `offload_to_cpu=True` to `checkpoint_wrapper` is confusing, because it causes the activation checkpointing args to be ignored and CPU offloading to be done instead. This isn't ideal from an API design perspective, so this RFC proposes making `offload_wrapper` its own concept. Offload to CPU and checkpointing can now be composed together, such as:

```
# apply AC to transformer layers
apply_ac_wrapper(model, checkpoint_wrapper, check_fn=lambda mod: isinstance(mod, TransformerLayer))

# offload the rest of activations to CPU
model = offload_wrapper(model)
```

Will polish / add tests if this proposal sounds good.

Differential Revision: [D39719854](https://our.internmc.facebook.com/intern/diff/D39719854/)

Pull Request resolved: pytorch#85459

Approved by: https://github.com/awgu
1 parent 100113b · commit bdefa26
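
For context, a minimal before/after sketch of the API change described in the commit message (the module below is illustrative, and `offload_wrapper` only exists once this change has landed):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper

module = nn.Linear(10, 10)

# Before this change, CPU offload was a flag on checkpoint_wrapper, and setting it
# caused the activation checkpointing arguments to be ignored:
#   module = checkpoint_wrapper(module, offload_to_cpu=True)

# After this change, CPU offload is its own wrapper, independent of checkpointing:
module = offload_wrapper(module)
```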

File tree

3 files changed: +151, -98 lines


test/distributed/fsdp/test_checkpoint_wrapper.py

Lines changed: 17 additions & 8 deletions
```diff
@@ -7,8 +7,10 @@
 import torch.nn as nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
+    offload_wrapper,
     apply_activation_checkpointing,
     CheckpointWrapper,
+    OffloadWrapper,
     CheckpointImpl
 )
 
@@ -21,6 +23,9 @@
 
 import unittest
 
+_SAVED_PREFIX = '_saved_'
+GRAD_FN_NEXT_FUNCTIONS = 'next_functions'
+
 class CheckpointWrapperTest(TestCase):
     def setUp(self):
         super().setUp()
@@ -72,11 +77,14 @@ def forward(self, a, b, c=None, d=None, **kwargs):
         for wrapper in [
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
-            partial(checkpoint_wrapper, offload_to_cpu=True),
+            offload_wrapper,
         ]:
             with self.subTest(wrapper=wrapper):
                 model = wrapper(MyModel())
-                self.assertTrue(isinstance(model, CheckpointWrapper))
+                if wrapper == offload_wrapper:
+                    self.assertTrue(isinstance(model, OffloadWrapper))
+                else:
+                    self.assertTrue(isinstance(model, CheckpointWrapper))
                 # Verify kwargs can be passed in
                 inp = torch.ones(4, 10, requires_grad=True)
                 out = model(inp, inp, c=inp, d=inp, e=inp, f=inp)
@@ -211,6 +219,7 @@ def check_fn(l):
         for wrapper in [
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
+            offload_wrapper,
         ]:
             model = MyModel()
             if n_linear is None:
@@ -223,12 +232,12 @@ def check_fn(l):
                 model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn
             )
             n_linear_wrapped = sum(1 if isinstance(x, nn.Linear) else 0 for x in model.modules())
-            n_checkpointed = sum(1 if isinstance(x, CheckpointWrapper) else 0 for x in model.modules())
+            n_checkpointed = sum(1 if isinstance(x, (CheckpointWrapper, OffloadWrapper)) else 0 for x in model.modules())
             self.assertEqual(n_checkpointed, n_linear_wrapped)
             self.assertEqual(n_linear, n_linear_wrapped)
             for j in range(3):
-                self.assertTrue(isinstance(model.seq[j].lin, CheckpointWrapper))
-                self.assertTrue(isinstance(model.seq[j].nested_linear[0], CheckpointWrapper))
+                self.assertTrue(isinstance(model.seq[j].lin, (CheckpointWrapper, OffloadWrapper)))
+                self.assertTrue(isinstance(model.seq[j].nested_linear[0], (CheckpointWrapper, OffloadWrapper)))
 
             inp = torch.randn(4, 10, requires_grad=True)
             for i in range(6):
@@ -276,7 +285,7 @@ def testing_cpu_offload_unpack_hook(packed):
         orig_init = torch.autograd.graph.saved_tensors_hooks.__init__
         torch.autograd.graph.saved_tensors_hooks.__init__ = patched_init
 
-        model = checkpoint_wrapper(model, offload_to_cpu=True)
+        model = offload_wrapper(model)
 
         inp = torch.randn(3, 10, device='cuda')
         loss = model(inp).sum()
@@ -286,7 +295,7 @@ def testing_cpu_offload_unpack_hook(packed):
 
         def dfs(grad_fn):
             for e in dir(grad_fn):
-                if not e.startswith('_saved_'):
+                if not e.startswith(_SAVED_PREFIX):
                     continue
 
                 saved = getattr(grad_fn, e)
@@ -295,7 +304,7 @@ def dfs(grad_fn):
                     nonlocal offload_verified
                     offload_verified = True
 
-            if hasattr(grad_fn, 'next_functions'):
+            if hasattr(grad_fn, GRAD_FN_NEXT_FUNCTIONS):
                 for next_grad_fn, _ in grad_fn.next_functions:
                     dfs(next_grad_fn)
 
```
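
The updated test also passes `offload_wrapper` directly as the `checkpoint_wrapper_fn` of `apply_activation_checkpointing`. A minimal sketch of that usage, assuming a build that includes this change (the model and `check_fn` here are illustrative):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    offload_wrapper,
    OffloadWrapper,
)

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))

# Wrap every nn.Linear submodule in place so its activations are offloaded to CPU.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=offload_wrapper,
    check_fn=lambda m: isinstance(m, nn.Linear),
)

# Each selected submodule is now wrapped in an OffloadWrapper, which is what the
# test above asserts via isinstance checks.
n_offloaded = sum(isinstance(m, OffloadWrapper) for m in model.modules())
print(f"{n_offloaded} submodules wrapped for CPU offload")
```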
test/distributed/fsdp/test_fsdp_checkpoint.py

Lines changed: 23 additions & 8 deletions
```diff
@@ -5,13 +5,15 @@
 from functools import partial
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
     CPUOffload,
 )
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
+    offload_wrapper,
 )
 from torch.testing._internal.common_distributed import (
     skip_if_lt_x_gpu,
@@ -65,9 +67,10 @@ def __init__(
             l3 = nn.Linear(3, 3).cuda()
 
             if checkpoint_layer:
-                ckpt_wrapper = partial(
-                    checkpoint_wrapper, offload_to_cpu=offload_activations
-                )
+                if offload_activations:
+                    ckpt_wrapper = offload_wrapper
+                else:
+                    ckpt_wrapper = checkpoint_wrapper
 
                 l1 = ckpt_wrapper(l1)
                 l2 = ckpt_wrapper(l2)
@@ -110,11 +113,15 @@ def _verify_parity(self, losses, outputs, models):
     @parametrize("offload_activations", [True, False])
     def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
         # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
-        ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
+        if offload_activations:
+            wrapper_to_use = offload_wrapper
+        else:
+            wrapper_to_use = checkpoint_wrapper
+
+        ckpt_sequential_wrapped_fsdp = wrapper_to_use(
             TestFSDPCheckpoint.SequentialModule(
                 wrap_fsdp=True, cpu_offload=cpu_offload
             ),
-            offload_to_cpu=offload_activations,
         )
         # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
         inner_ckpt = TestFSDPCheckpoint.SequentialModule(
@@ -153,6 +160,8 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
 
         self._verify_parity(losses, outputs, models)
 
+        dist.barrier()
+
     @skip_if_lt_x_gpu(2)
     @parametrize(
         "cpu_offload",
@@ -166,13 +175,17 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
         # Runs FSDP with no checkpointing
         fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
         # Runs checkpoint-wrapped FSDP
-        checkpointed_fsdp = checkpoint_wrapper(
+        if offload_activations:
+            wrapper_to_use = offload_wrapper
+        else:
+            wrapper_to_use = checkpoint_wrapper
+
+        checkpointed_fsdp = wrapper_to_use(
             FSDP(deepcopy(seq), cpu_offload=cpu_offload),
-            offload_to_cpu=offload_activations,
         )
         # Runs FSDP-wrapped checkpointed module
         fsdp_wrapped_checkpoint = FSDP(
-            checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
+            wrapper_to_use(deepcopy(seq)),
             cpu_offload=cpu_offload,
         )
         # Runs FSDP with manual calls to checkpoint.
@@ -220,6 +233,8 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
 
         self._verify_parity(losses, outputs, models)
 
+        dist.barrier()
+
 instantiate_parametrized_tests(TestFSDPCheckpoint)
 
 if __name__ == "__main__":
```
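
As exercised by the FSDP test above, the new wrapper composes with FSDP in either nesting order. A rough sketch of the two arrangements the test compares for parity, assuming a multi-GPU run with the process group already initialized (the toy module is illustrative):

```python
from copy import deepcopy

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
)

seq = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)).cuda()

# Offload wrapper around an FSDP instance: activations of the sharded module go to CPU.
offload_wrapped_fsdp = offload_wrapper(FSDP(deepcopy(seq)))

# FSDP around an offload-wrapped module: the offload wrapper lives inside the FSDP unit.
fsdp_wrapped_offload = FSDP(offload_wrapper(deepcopy(seq)))
```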
