Skip to content

Commit 4e0acbb

Browse files
diego-urgell authored and facebook-github-bot committed
Use CheckpointPath in get_x_checkpoint functions
Reviewed By: JKSenthil
Differential Revision: D56427223
1 parent e1135d6 commit 4e0acbb

File tree

3 files changed

+110
-122
lines changed

3 files changed

+110
-122
lines changed

tests/utils/test_checkpoint.py

+58-31
Original file line number | Diff line number | Diff line change
@@ -234,7 +234,8 @@ def test_latest_checkpoint_path(self) -> None:
234234
path_4 = os.path.join(temp_dir, "epoch_700")
235235
os.mkdir(path_4)
236236
self.assertEqual(
237-
get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2
237+
get_latest_checkpoint_path(temp_dir, METADATA_FNAME),
238+
path_2,
238239
)
239240

240241
@skip_if_not_distributed
@@ -284,7 +285,8 @@ def _latest_checkpoint_path_distributed() -> None:
284285
expected_path = path_container[0]
285286
tc.assertIsNotNone(expected_path)
286287
tc.assertEqual(
287-
get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path
288+
get_latest_checkpoint_path(temp_dir, METADATA_FNAME),
289+
expected_path,
288290
)
289291

290292
if is_rank0:
@@ -368,7 +370,12 @@ def test_retrieve_checkpoint_dirpaths(self) -> None:
368370
# compares set equality since order of returned dirpaths is not guaranteed
369371
# in _retrieve_checkpoint_dirpaths
370372
self.assertEqual(
371-
set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)),
373+
{
374+
str(x)
375+
for x in _retrieve_checkpoint_dirpaths(
376+
temp_dir, metadata_fname=None
377+
)
378+
},
372379
{os.path.join(temp_dir, path) for path in paths[:-1]},
373380
)
374381
self.assertEqual(
@@ -382,9 +389,12 @@ def test_retrieve_checkpoint_dirpaths(self) -> None:
382389
pass
383390

384391
self.assertEqual(
385-
set(
386-
_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata")
387-
),
392+
{
393+
str(x)
394+
for x in _retrieve_checkpoint_dirpaths(
395+
temp_dir, metadata_fname=".metadata"
396+
)
397+
},
388398
{os.path.join(temp_dir, paths[2])},
389399
)
390400

@@ -394,30 +404,36 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
394404
"""
395405
with tempfile.TemporaryDirectory() as temp_dir:
396406
paths = [
397-
"epoch_0_step_10_val_loss=10",
398-
"epoch_1_step_10_val_loss=5",
407+
"epoch_0_step_10_val_loss=10.0",
408+
"epoch_1_step_10_val_loss=5.0",
399409
"epoch_2_step_10",
400410
"epoch_0_step_5",
401-
"epoch_0_step_6_train_loss=13",
411+
"epoch_0_step_6_train_loss=13.0",
402412
]
403413
for path in paths:
404414
os.mkdir(os.path.join(temp_dir, path))
405415
# make last path a file instead of a directory
406-
with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3"), "w"):
416+
with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3.0"), "w"):
407417
pass
408418

409419
# compares set equality since order of returned dirpaths is not guaranteed
410420
# in _retrieve_checkpoint_dirpaths
411421
self.assertEqual(
412-
set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)),
422+
{
423+
str(x)
424+
for x in _retrieve_checkpoint_dirpaths(
425+
temp_dir, metadata_fname=None
426+
)
427+
},
413428
{os.path.join(temp_dir, path) for path in paths},
414429
)
415430
self.assertEqual(
416-
set(
417-
_retrieve_checkpoint_dirpaths(
431+
{
432+
str(x)
433+
for x in _retrieve_checkpoint_dirpaths(
418434
temp_dir, metadata_fname=None, metric_name="val_loss"
419435
)
420-
),
436+
},
421437
{
422438
os.path.join(temp_dir, path) for path in paths[:2]
423439
}, # since last path is a file
@@ -433,11 +449,12 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
433449
pass
434450

435451
self.assertEqual(
436-
set(
437-
_retrieve_checkpoint_dirpaths(
452+
{
453+
str(x)
454+
for x in _retrieve_checkpoint_dirpaths(
438455
temp_dir, metadata_fname=".metadata", metric_name="val_loss"
439456
)
440-
),
457+
},
441458
{os.path.join(temp_dir, paths[1])},
442459
)
443460

@@ -467,7 +484,7 @@ def create_tmp_dir() -> str:
467484
os.mkdir(path2)
468485
torch.distributed.barrier()
469486

470-
ckpt_dirpaths = get_checkpoint_dirpaths(temp_dir)
487+
ckpt_dirpaths = [str(x) for x in get_checkpoint_dirpaths(temp_dir)]
471488
tc = unittest.TestCase()
472489
tc.assertEqual(set(ckpt_dirpaths), {path1, path2})
473490

@@ -492,7 +509,7 @@ def test_get_checkpoint_dirpaths(self) -> None:
492509
os.mkdir(path3)
493510

494511
self.assertEqual(
495-
set(get_checkpoint_dirpaths(temp_dir)),
512+
{str(x) for x in get_checkpoint_dirpaths(temp_dir)},
496513
{path1, path2, path3},
497514
)
498515

@@ -505,7 +522,10 @@ def test_get_checkpoint_dirpaths(self) -> None:
505522
os.mkdir(path3)
506523

507524
self.assertEqual(
508-
set(get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")),
525+
{
526+
str(x)
527+
for x in get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")
528+
},
509529
{path1, path2, path3},
510530
)
511531

@@ -519,20 +539,27 @@ def test_checkpoint_sorting_utils(self) -> None:
519539
"""
520540
Tests the sort utilities
521541
"""
522-
paths = ["epoch_1_step_20", "epoch_4_step_130", "epoch_0_step_10_val_loss=10"]
523-
self.assertEqual(_sort_by_recency(paths), [paths[2], paths[0], paths[1]])
542+
paths = [
543+
"foo/epoch_1_step_20",
544+
"foo/epoch_4_step_130",
545+
"foo/epoch_0_step_10_val_loss=10.0",
546+
]
547+
ckpts = [CheckpointPath.from_str(x) for x in paths]
548+
sorted_paths = [str(x) for x in _sort_by_recency(ckpts)]
549+
self.assertEqual(sorted_paths, [paths[2], paths[0], paths[1]])
524550

525551
paths = [
526-
"epoch_1_step_20_val_loss=0.09",
527-
"epoch_4_step_130_val_loss=29",
528-
"epoch_0_step_10_val_loss=10",
552+
"foo/epoch_1_step_20_val_loss=0.09",
553+
"foo/epoch_4_step_130_val_loss=29.0",
554+
"foo/epoch_0_step_10_val_loss=10.0",
529555
]
530-
self.assertEqual(
531-
_sort_by_metric_value(paths, mode="min"), [paths[1], paths[2], paths[0]]
532-
)
533-
self.assertEqual(
534-
_sort_by_metric_value(paths, mode="max"), [paths[0], paths[2], paths[1]]
535-
)
556+
ckpts = [CheckpointPath.from_str(x) for x in paths]
557+
558+
sorted_paths = [str(x) for x in _sort_by_metric_value(ckpts, mode="min")]
559+
self.assertEqual(sorted_paths, [paths[1], paths[2], paths[0]])
560+
561+
sorted_paths = [str(x) for x in _sort_by_metric_value(ckpts, mode="max")]
562+
self.assertEqual(sorted_paths, [paths[0], paths[2], paths[1]])
536563

537564
def test_delete_checkpoint(self) -> None:
538565
"""

torchtnt/framework/callbacks/base_checkpointer.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -124,11 +124,14 @@ def __init__(
124124

125125
# sort by metric value if doing best checkpoint, else by recency
126126
if best_checkpoint_config:
127-
self._ckpt_dirpaths = _sort_by_metric_value(
127+
ckpt_dirpaths = _sort_by_metric_value(
128128
ckpt_dirpaths, mode=best_checkpoint_config.mode
129129
)
130130
else:
131-
self._ckpt_dirpaths = _sort_by_recency(ckpt_dirpaths)
131+
ckpt_dirpaths = _sort_by_recency(ckpt_dirpaths)
132+
133+
# TODO Remove this when using CheckpointManager
134+
self._ckpt_dirpaths = [str(x) for x in ckpt_dirpaths]
132135

133136
self._process_group: Optional[dist.ProcessGroup] = None
134137
self._setup_gloo_pg(process_group)

0 commit comments

Comments
 (0)