@@ -234,7 +234,8 @@ def test_latest_checkpoint_path(self) -> None:
234
234
path_4 = os .path .join (temp_dir , "epoch_700" )
235
235
os .mkdir (path_4 )
236
236
self .assertEqual (
237
- get_latest_checkpoint_path (temp_dir , METADATA_FNAME ), path_2
237
+ get_latest_checkpoint_path (temp_dir , METADATA_FNAME ),
238
+ path_2 ,
238
239
)
239
240
240
241
@skip_if_not_distributed
@@ -284,7 +285,8 @@ def _latest_checkpoint_path_distributed() -> None:
284
285
expected_path = path_container [0 ]
285
286
tc .assertIsNotNone (expected_path )
286
287
tc .assertEqual (
287
- get_latest_checkpoint_path (temp_dir , METADATA_FNAME ), expected_path
288
+ get_latest_checkpoint_path (temp_dir , METADATA_FNAME ),
289
+ expected_path ,
288
290
)
289
291
290
292
if is_rank0 :
@@ -368,7 +370,12 @@ def test_retrieve_checkpoint_dirpaths(self) -> None:
368
370
# compares set equality since order of returned dirpaths is not guaranteed
369
371
# in _retrieve_checkpoint_dirpaths
370
372
self .assertEqual (
371
- set (_retrieve_checkpoint_dirpaths (temp_dir , metadata_fname = None )),
373
+ {
374
+ str (x )
375
+ for x in _retrieve_checkpoint_dirpaths (
376
+ temp_dir , metadata_fname = None
377
+ )
378
+ },
372
379
{os .path .join (temp_dir , path ) for path in paths [:- 1 ]},
373
380
)
374
381
self .assertEqual (
@@ -382,9 +389,12 @@ def test_retrieve_checkpoint_dirpaths(self) -> None:
382
389
pass
383
390
384
391
self .assertEqual (
385
- set (
386
- _retrieve_checkpoint_dirpaths (temp_dir , metadata_fname = ".metadata" )
387
- ),
392
+ {
393
+ str (x )
394
+ for x in _retrieve_checkpoint_dirpaths (
395
+ temp_dir , metadata_fname = ".metadata"
396
+ )
397
+ },
388
398
{os .path .join (temp_dir , paths [2 ])},
389
399
)
390
400
@@ -394,30 +404,36 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
394
404
"""
395
405
with tempfile .TemporaryDirectory () as temp_dir :
396
406
paths = [
397
- "epoch_0_step_10_val_loss=10" ,
398
- "epoch_1_step_10_val_loss=5" ,
407
+ "epoch_0_step_10_val_loss=10.0 " ,
408
+ "epoch_1_step_10_val_loss=5.0 " ,
399
409
"epoch_2_step_10" ,
400
410
"epoch_0_step_5" ,
401
- "epoch_0_step_6_train_loss=13" ,
411
+ "epoch_0_step_6_train_loss=13.0 " ,
402
412
]
403
413
for path in paths :
404
414
os .mkdir (os .path .join (temp_dir , path ))
405
415
# make last path a file instead of a directory
406
- with open (os .path .join (temp_dir , "epoch_0_step_3_val_loss=3" ), "w" ):
416
+ with open (os .path .join (temp_dir , "epoch_0_step_3_val_loss=3.0 " ), "w" ):
407
417
pass
408
418
409
419
# compares set equality since order of returned dirpaths is not guaranteed
410
420
# in _retrieve_checkpoint_dirpaths
411
421
self .assertEqual (
412
- set (_retrieve_checkpoint_dirpaths (temp_dir , metadata_fname = None )),
422
+ {
423
+ str (x )
424
+ for x in _retrieve_checkpoint_dirpaths (
425
+ temp_dir , metadata_fname = None
426
+ )
427
+ },
413
428
{os .path .join (temp_dir , path ) for path in paths },
414
429
)
415
430
self .assertEqual (
416
- set (
417
- _retrieve_checkpoint_dirpaths (
431
+ {
432
+ str (x )
433
+ for x in _retrieve_checkpoint_dirpaths (
418
434
temp_dir , metadata_fname = None , metric_name = "val_loss"
419
435
)
420
- ) ,
436
+ } ,
421
437
{
422
438
os .path .join (temp_dir , path ) for path in paths [:2 ]
423
439
}, # since last path is a file
@@ -433,11 +449,12 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
433
449
pass
434
450
435
451
self .assertEqual (
436
- set (
437
- _retrieve_checkpoint_dirpaths (
452
+ {
453
+ str (x )
454
+ for x in _retrieve_checkpoint_dirpaths (
438
455
temp_dir , metadata_fname = ".metadata" , metric_name = "val_loss"
439
456
)
440
- ) ,
457
+ } ,
441
458
{os .path .join (temp_dir , paths [1 ])},
442
459
)
443
460
@@ -467,7 +484,7 @@ def create_tmp_dir() -> str:
467
484
os .mkdir (path2 )
468
485
torch .distributed .barrier ()
469
486
470
- ckpt_dirpaths = get_checkpoint_dirpaths (temp_dir )
487
+ ckpt_dirpaths = [ str ( x ) for x in get_checkpoint_dirpaths (temp_dir )]
471
488
tc = unittest .TestCase ()
472
489
tc .assertEqual (set (ckpt_dirpaths ), {path1 , path2 })
473
490
@@ -492,7 +509,7 @@ def test_get_checkpoint_dirpaths(self) -> None:
492
509
os .mkdir (path3 )
493
510
494
511
self .assertEqual (
495
- set ( get_checkpoint_dirpaths (temp_dir )) ,
512
+ { str ( x ) for x in get_checkpoint_dirpaths (temp_dir )} ,
496
513
{path1 , path2 , path3 },
497
514
)
498
515
@@ -505,7 +522,10 @@ def test_get_checkpoint_dirpaths(self) -> None:
505
522
os .mkdir (path3 )
506
523
507
524
self .assertEqual (
508
- set (get_checkpoint_dirpaths (temp_dir , metric_name = "val_loss" )),
525
+ {
526
+ str (x )
527
+ for x in get_checkpoint_dirpaths (temp_dir , metric_name = "val_loss" )
528
+ },
509
529
{path1 , path2 , path3 },
510
530
)
511
531
@@ -519,20 +539,27 @@ def test_checkpoint_sorting_utils(self) -> None:
519
539
"""
520
540
Tests the sort utilities
521
541
"""
522
- paths = ["epoch_1_step_20" , "epoch_4_step_130" , "epoch_0_step_10_val_loss=10" ]
523
- self .assertEqual (_sort_by_recency (paths ), [paths [2 ], paths [0 ], paths [1 ]])
542
+ paths = [
543
+ "foo/epoch_1_step_20" ,
544
+ "foo/epoch_4_step_130" ,
545
+ "foo/epoch_0_step_10_val_loss=10.0" ,
546
+ ]
547
+ ckpts = [CheckpointPath .from_str (x ) for x in paths ]
548
+ sorted_paths = [str (x ) for x in _sort_by_recency (ckpts )]
549
+ self .assertEqual (sorted_paths , [paths [2 ], paths [0 ], paths [1 ]])
524
550
525
551
paths = [
526
- "epoch_1_step_20_val_loss=0.09" ,
527
- "epoch_4_step_130_val_loss=29" ,
528
- "epoch_0_step_10_val_loss=10" ,
552
+ "foo/ epoch_1_step_20_val_loss=0.09" ,
553
+ "foo/ epoch_4_step_130_val_loss=29.0 " ,
554
+ "foo/ epoch_0_step_10_val_loss=10.0 " ,
529
555
]
530
- self .assertEqual (
531
- _sort_by_metric_value (paths , mode = "min" ), [paths [1 ], paths [2 ], paths [0 ]]
532
- )
533
- self .assertEqual (
534
- _sort_by_metric_value (paths , mode = "max" ), [paths [0 ], paths [2 ], paths [1 ]]
535
- )
556
+ ckpts = [CheckpointPath .from_str (x ) for x in paths ]
557
+
558
+ sorted_paths = [str (x ) for x in _sort_by_metric_value (ckpts , mode = "min" )]
559
+ self .assertEqual (sorted_paths , [paths [1 ], paths [2 ], paths [0 ]])
560
+
561
+ sorted_paths = [str (x ) for x in _sort_by_metric_value (ckpts , mode = "max" )]
562
+ self .assertEqual (sorted_paths , [paths [0 ], paths [2 ], paths [1 ]])
536
563
537
564
def test_delete_checkpoint (self ) -> None :
538
565
"""
0 commit comments