Skip to content

Commit 216d2c0

Browse files
diego-urgell authored and facebook-github-bot committed
Implement custom get/setstate for CheckpointPath
Reviewed By: JKSenthil Differential Revision: D56654810
1 parent 698d4d0 commit 216d2c0

File tree

3 files changed

+111
-0
lines changed

3 files changed

+111
-0
lines changed

tests/utils/test_checkpoint.py

+16
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88
import os
9+
import pickle
910
import shutil
1011
import tempfile
1112
import unittest
@@ -173,6 +174,21 @@ def test_compare_by_optimality(self) -> None:
173174
self.assertTrue(smaller.more_optimal_than(larger, mode="min"))
174175
self.assertFalse(larger.more_optimal_than(smaller, mode="min"))
175176

177+
def test_pickling(self) -> None:
    """
    Round-trip ``CheckpointPath`` objects through pickle.

    For each sample path, the pickled payload must contain the path string
    and unpickling must produce an object equal to the original.
    """
    for path in (
        "foo/epoch_0_step_1",
        "file://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98",
    ):
        # subTest reports which path failed instead of aborting the loop
        with self.subTest(path=path):
            ckpt = CheckpointPath.from_str(path)

            pickled = pickle.dumps(ckpt)

            # The payload also carries pickle framing, so check containment
            # rather than equality. Searching the raw bytes avoids
            # `str(bytes)`, which triggers BytesWarning under `python -b`.
            self.assertIn(path.encode("utf-8"), pickled)

            unpickled = pickle.loads(pickled)
            self.assertEqual(unpickled, ckpt)
191+
176192

177193
class CheckpointUtilsTest(unittest.TestCase):
178194
@staticmethod

tests/utils/test_checkpoint_gpu.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import os
10+
import shutil
11+
import tempfile
12+
import unittest
13+
14+
import torch.distributed as dist
15+
from torchtnt.utils import init_from_env
16+
from torchtnt.utils.checkpoint import get_checkpoint_dirpaths
17+
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
18+
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
19+
20+
21+
class TestCheckpointUtilsGPU(unittest.TestCase):
    """Distributed (NCCL) tests for ``get_checkpoint_dirpaths``."""

    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_get_checkpoint_dirpaths_distributed(self) -> None:
        """Run ``_test_get_checkpoint_dirpaths`` on two processes with NCCL."""
        spawn_multi_process(
            2,
            "nccl",
            self._test_get_checkpoint_dirpaths,
        )

    @staticmethod
    def _test_get_checkpoint_dirpaths() -> None:
        """
        Tests retrieving checkpoint directories from a given root directory
        using NCCL on GPUs with custom state for pickling.

        Rank 0 creates the checkpoint directories; every rank then calls
        ``get_checkpoint_dirpaths`` and must end up with the same set of paths.
        """
        init_from_env()
        paths = [
            "epoch_0_step_10",
            "epoch_1_step_10_val_loss=10.5",
            "epoch_2_step_10",
            "epoch_0_step_5",
            "epoch_0_step_6_acc=0.03",
            "epoch_0_step_3",
        ]

        # Only rank 0 creates (and therefore knows about) the directory;
        # every other rank passes None to `get_checkpoint_dirpaths`.
        temp_dir = tempfile.mkdtemp() if get_global_rank() == 0 else None

        try:
            if temp_dir is not None:
                for path in paths:
                    os.mkdir(os.path.join(temp_dir, path))

            tc = unittest.TestCase()

            ckpt_dirpaths = get_checkpoint_dirpaths(
                temp_dir, process_group=dist.group.WORLD
            )

            # Broadcast rank 0's directory so every rank can build the
            # expected set of paths and verify successful execution.
            broadcast_buffer = [temp_dir]
            dist.broadcast_object_list(
                broadcast_buffer, src=0, group=dist.group.WORLD
            )
            shared_dir = broadcast_buffer[0]
            tc.assertIsNotNone(shared_dir)

            tc.assertEqual(
                {str(x) for x in ckpt_dirpaths},
                {os.path.join(shared_dir, path) for path in paths},
            )
        finally:
            # Clean up even when an assertion above fails, so a failing run
            # does not leak the temporary directory.
            if temp_dir is not None:
                shutil.rmtree(temp_dir)

torchtnt/utils/checkpoint.py

+19
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,25 @@ def __eq__(self, other: "CheckpointPath") -> bool:
198198
def __gt__(self, other: "CheckpointPath") -> bool:
    """Implement ``self > other`` by delegating to ``newer_than``."""
    is_newer = self.newer_than(other)
    return is_newer
200200

201+
def __getstate__(self) -> str:
    """
    Return the state that pickle stores for this object.

    Only the ``path`` string is serialized; ``__setstate__`` parses it back
    into the individual fields when unpickling.
    """
    # Lightweight pickling to avoid broadcast errors
    serialized_path = self.path
    return serialized_path
204+
205+
def __setstate__(self, state: str) -> None:
    """
    Restore this instance from the path string stored in the pickle.

    Args:
        state: checkpoint path string; it is parsed with ``PATH_REGEX`` to
            recover the directory, epoch, step and optional metric.

    Raises:
        ValueError: If ``state`` does not match ``PATH_REGEX``.
    """
    # Match regex directly to avoid creating a new instance with `from_str`
    path_match = self.PATH_REGEX.match(state)
    if path_match is None:
        # Raise explicitly instead of using `assert`: assertions are stripped
        # under `python -O`, which would turn a malformed payload into an
        # opaque AttributeError on the `.groups()` call below.
        raise ValueError(f"Malformed checkpoint found when unpickling: {state}")

    dirpath, epoch, step, metric_name, metric_value = path_match.groups()
    self.dirpath = dirpath.rstrip("/")
    self.epoch = int(epoch)
    self.step = int(step)
    # The metric is optional: both name and value must have been captured
    self.metric_data = (
        MetricData(name=metric_name, value=float(metric_value))
        if metric_name and metric_value
        else None
    )
219+
201220

202221
@rank_zero_read_and_broadcast
203222
def get_latest_checkpoint_path(

0 commit comments

Comments
 (0)