From 47438973b803fb33e59466423a4b2b803087ac56 Mon Sep 17 00:00:00 2001
From: Diego Urgell <diegourgell@fb.com>
Date: Fri, 26 Apr 2024 01:25:57 -0700
Subject: [PATCH 1/2] Add CheckpointPath abstraction in utils/checkpoint.py

Differential Revision: D56260188
---
 tests/utils/test_checkpoint.py | 143 ++++++++++++++++++++++++++
 torchtnt/utils/__init__.py     |   3 +
 torchtnt/utils/checkpoint.py   | 178 +++++++++++++++++++++++++++++++++
 3 files changed, 324 insertions(+)
 create mode 100644 tests/utils/test_checkpoint.py
 create mode 100644 torchtnt/utils/checkpoint.py

diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py
new file mode 100644
index 0000000000..91778df115
--- /dev/null
+++ b/tests/utils/test_checkpoint.py
@@ -0,0 +1,143 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import unittest
+
+from torchtnt.utils.checkpoint import CheckpointPath, MetricData
+
+
+class CheckpointPathTest(unittest.TestCase):
+    def test_from_str(self) -> None:
+        # invalid paths
+        malformed_paths = [
+            "foo/step_20",
+            "foo/epoch_50",
+            "epoch_30",
+            "foo/epoch_20_step",
+            "foo/epoch_20_step_30_val_loss=1a",
+            "foo/epoch_2_step_15_mean=hello",
+            "foo/epoch_2.6_step_23",
+        ]
+        for path in malformed_paths:
+            with self.assertRaisesRegex(
+                ValueError, f"Attempted to parse malformed checkpoint path: {path}"
+            ):
+                CheckpointPath.from_str(path)
+
+        # valid paths
+        valid_paths = [
+            ("foo/epoch_0_step_1", CheckpointPath("foo", epoch=0, step=1)),
+            (
+                "foo/epoch_14_step_3_mean=15.0",
+                CheckpointPath(
+                    "foo", epoch=14, step=3, metric_data=MetricData("mean", 15.0)
+                ),
+            ),
+            (
+                "foo/epoch_14_step_3_loss=-27.35",
+                CheckpointPath(
+                    "foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35)
+                ),
+            ),
+            (
+                "/foo/epoch_14_step_3_loss=-27.35",
+                CheckpointPath(
+                    "/foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35)
+                ),
+            ),
+            (
+                "foo/bar/epoch_23_step_31_mean_loss_squared=0.0",
+                CheckpointPath(
+                    "foo/bar/",
+                    epoch=23,
+                    step=31,
+                    metric_data=MetricData("mean_loss_squared", 0.0),
+                ),
+            ),
+            (
+                "oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98",
+                CheckpointPath(
+                    "oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61",
+                    epoch=2,
+                    step=1,
+                    metric_data=MetricData("acc", 0.98),
+                ),
+            ),
+        ]
+        for path, expected_ckpt in valid_paths:
+            parsed_ckpt = CheckpointPath.from_str(path)
+            self.assertEqual(parsed_ckpt, expected_ckpt)
+            self.assertEqual(parsed_ckpt.path, path)
+
+        # with a trailing slash
+        ckpt = CheckpointPath.from_str("foo/epoch_0_step_1/")
+        self.assertEqual(ckpt, CheckpointPath("foo", epoch=0, step=1))
+        self.assertEqual(ckpt.path, "foo/epoch_0_step_1")
+
+    def test_compare_by_recency(self) -> None:
+        old = CheckpointPath("foo", epoch=0, step=1)
+        new = CheckpointPath("foo", epoch=1, step=1)
+        self.assertTrue(new.newer_than(old))
+        self.assertFalse(old.newer_than(new))
+        self.assertFalse(new == old)
+
+        old = CheckpointPath("foo", epoch=3, step=5)
+        new = CheckpointPath("foo", epoch=3, step=9)
+        self.assertTrue(new.newer_than(old))
+        self.assertFalse(old.newer_than(new))
+        self.assertFalse(new == old)
+
+        twin1 = CheckpointPath(
+            "foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0)
+        )
+        almost_twin = CheckpointPath(
+            "foo", epoch=2, step=5, metric_data=MetricData("bar", 2.0)
+        )
+
+        self.assertFalse(twin1.newer_than(almost_twin))
+        self.assertFalse(almost_twin.newer_than(twin1))
+        self.assertFalse(twin1 == almost_twin)
+
+        twin2 = CheckpointPath(
+            "foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0)
+        )
+        self.assertTrue(twin1 == twin2)
+
+    def test_compare_by_optimality(self) -> None:
+        # not both metric aware
+        ckpt1 = CheckpointPath("foo", epoch=0, step=1)
+        ckpt2 = CheckpointPath("foo", epoch=1, step=1)
+        ckpt3 = CheckpointPath(
+            "foo", epoch=1, step=1, metric_data=MetricData("bar", 1.0)
+        )
+        for ckpt in [ckpt2, ckpt3]:
+            with self.assertRaisesRegex(
+                AssertionError,
+                "Attempted to compare optimality of non metric-aware checkpoints",
+            ):
+                ckpt1.more_optimal_than(ckpt, mode="min")
+
+        # tracking different metrics
+        ckpt4 = CheckpointPath(
+            "foo", epoch=1, step=1, metric_data=MetricData("baz", 1.0)
+        )
+        with self.assertRaisesRegex(
+            AssertionError,
+            "Attempted to compare optimality of checkpoints tracking different metrics",
+        ):
+            ckpt3.more_optimal_than(ckpt4, mode="min")
+
+        smaller = CheckpointPath(
+            "foo", epoch=0, step=1, metric_data=MetricData("foo", 1.0)
+        )
+        larger = CheckpointPath(
+            "foo", epoch=0, step=1, metric_data=MetricData("foo", 2.0)
+        )
+        self.assertTrue(larger.more_optimal_than(smaller, mode="max"))
+        self.assertFalse(smaller.more_optimal_than(larger, mode="max"))
+        self.assertTrue(smaller.more_optimal_than(larger, mode="min"))
+        self.assertFalse(larger.more_optimal_than(smaller, mode="min"))
diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py
index cb973c13a6..c0ad4c3b8d 100644
--- a/torchtnt/utils/__init__.py
+++ b/torchtnt/utils/__init__.py
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+from .checkpoint import CheckpointPath, MetricData
 from .device import (
     copy_data_to_device,
     CPUStats,
@@ -148,4 +149,6 @@
     "is_windows",
     "get_pet_launch_config",
     "spawn_multi_process",
+    "CheckpointPath",
+    "MetricData",
 ]
diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py
new file mode 100644
index 0000000000..234464b6a7
--- /dev/null
+++ b/torchtnt/utils/checkpoint.py
@@ -0,0 +1,178 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import os
+import re
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Literal, Optional, Pattern
+
+from pyre_extensions import none_throws
+
+
+@dataclass
+class MetricData:
+    """
+    Representation of a metric instance. Should provide both a metric name and its value.
+    """
+
+    name: str
+    value: float
+
+
+@total_ordering
+class CheckpointPath:
+    """
+    Representation of a checkpoint path. Handles parsing and serialization of the specific path format.
+    Currently, the basic compliant path format is: <dirpath>/epoch_<epoch>_step_<step>
+    If a metric is being tracked, it's added to the name: <dirpath>/epoch_<epoch>_step_<step>_<metric_name>=<metric_value>
+
+    This class is ordered by checkpoint recency: ordering comparisons use only the epoch + step (equality also checks
+    dirpath and metric_data). Sorting by metric can be done by extracting the metric value from the metric_data attribute.
+    """
+
+    PATH_REGEX: Pattern = re.compile(
+        r"^(.+)epoch_(\d+)_step_(\d+)(?:_(.+)=(-?\d+\.?\d*))?\/?$"
+    )
+
+    def __init__(
+        self,
+        dirpath: str,
+        epoch: int,
+        step: int,
+        metric_data: Optional[MetricData] = None,
+    ) -> None:
+        """
+        Args:
+            dirpath: The base directory path that checkpoints are saved in.
+            epoch: The epoch number of this checkpoint.
+            step: The step number of this checkpoint.
+            metric_data: Optional data about the metric being tracked. Should contain both metric name and value.
+        """
+        self.dirpath: str = dirpath.rstrip("/")
+        self.epoch = epoch
+        self.step = step
+        self.metric_data = metric_data
+
+    @classmethod
+    def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
+        """
+        Given a directory path, try to parse it and extract the checkpoint data.
+        The expected format is: <dirpath>/epoch_<epoch>_step_<step>_<metric_name>=<metric_value>,
+        where the metric name and value are optional.
+
+        Args:
+            checkpoint_path: The path to the checkpoint directory.
+
+        Returns:
+            A CheckpointPath instance parsed from the given path.
+
+        Raises:
+            ValueError: If the path is malformed and can't be parsed.
+        """
+        path_match = cls.PATH_REGEX.match(checkpoint_path)
+        if not path_match:
+            raise ValueError(
+                f"Attempted to parse malformed checkpoint path: {checkpoint_path}."
+            )
+
+        dirpath, epoch, step, metric_name, metric_value = path_match.groups()
+        try:
+            metric_data: Optional[MetricData] = None
+            if metric_name:
+                metric_value_f = float(metric_value)
+                metric_data = MetricData(name=metric_name, value=metric_value_f)
+
+            return CheckpointPath(
+                dirpath=dirpath,
+                epoch=int(epoch),
+                step=int(step),
+                metric_data=metric_data,
+            )
+
+        except ValueError:
+            # Should never happen since path matches regex
+            raise ValueError(
+                f"Invalid data types found in checkpoint path: {checkpoint_path}."
+            )
+
+    @property
+    def path(self) -> str:
+        """
+        Returns:
+            The full path to the checkpoint directory.
+        """
+        name = f"epoch_{self.epoch}_step_{self.step}"
+        if self.metric_data:
+            name += f"_{self.metric_data.name}={self.metric_data.value}"
+
+        return os.path.join(self.dirpath, name)
+
+    def newer_than(self, other: "CheckpointPath") -> bool:
+        """
+        Given another CheckpointPath instance, determine if this checkpoint is strictly newer than the other.
+
+        Returns:
+            True if this checkpoint is newer than the other, otherwise False.
+        """
+        if self.epoch != other.epoch:
+            return self.epoch > other.epoch
+
+        return self.step > other.step
+
+    def more_optimal_than(
+        self, other: "CheckpointPath", mode: Literal["min", "max"]
+    ) -> bool:
+        """
+        Given another CheckpointPath instance, determine if this checkpoint is strictly more optimal than the other.
+        Optimality is determined by comparing the metric value of the two checkpoints. The mode indicates if the
+        metric value should be minimized or maximized. This only works for metric-aware checkpoints.
+
+        Args:
+            other: The other checkpoint path to compare against.
+            mode: The mode to use for comparison.
+
+        Returns:
+            True if this checkpoint is more optimal than the other, otherwise False.
+
+        Note: This expects that both checkpoints are metric-aware, and that they are tracking the same metric.
+        """
+
+        assert (
+            self.metric_data and other.metric_data
+        ), f"Attempted to compare optimality of non metric-aware checkpoints: {self} and {other}"
+
+        assert (
+            self.metric_data.name == other.metric_data.name
+        ), f"Attempted to compare optimality of checkpoints tracking different metrics: {self} and {other}"
+
+        if mode == "min":
+            return (
+                none_throws(self.metric_data).value
+                < none_throws(other.metric_data).value
+            )
+
+        return (
+            none_throws(self.metric_data).value > none_throws(other.metric_data).value
+        )
+
+    def __str__(self) -> str:
+        return self.path
+
+    def __repr__(self) -> str:
+        return f"CheckpointPath(dirpath={self.dirpath}, epoch={self.epoch}, step={self.step}, metric_data={self.metric_data})"
+
+    def __eq__(self, other: "CheckpointPath") -> bool:
+        return (
+            self.dirpath == other.dirpath
+            and self.epoch == other.epoch
+            and self.step == other.step
+            and self.metric_data == other.metric_data
+        )
+
+    def __gt__(self, other: "CheckpointPath") -> bool:
+        return self.newer_than(other)

From 888305f7115d0ed7059fc13cd765ec8e053a6083 Mon Sep 17 00:00:00 2001
From: Diego Urgell <diegourgell@meta.com>
Date: Fri, 26 Apr 2024 01:26:11 -0700
Subject: [PATCH 2/2] Move `get_x_checkpoint` functions to
 `utils/checkpoint.py`

Reviewed By: JKSenthil

Differential Revision: D56450720
---
 .../callbacks/test_base_checkpointer.py       |   2 +-
 .../callbacks/test_checkpoint_utils.py        | 395 -----------------
 tests/utils/test_checkpoint.py                | 408 +++++++++++++++++-
 .../framework/callbacks/_checkpoint_utils.py  | 254 +----------
 .../framework/callbacks/base_checkpointer.py  |  16 +-
 torchtnt/utils/__init__.py                    |  11 +-
 torchtnt/utils/checkpoint.py                  | 249 ++++++++++-
 7 files changed, 675 insertions(+), 660 deletions(-)

diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py
index fb8fea71d7..105d6052ea 100644
--- a/tests/framework/callbacks/test_base_checkpointer.py
+++ b/tests/framework/callbacks/test_base_checkpointer.py
@@ -249,7 +249,7 @@ def test_restore_from_latest_empty_dir(self) -> None:
                 self.assertEqual(
                     log.output,
                     [
-                        f"WARNING:torchtnt.framework.callbacks._checkpoint_utils:Input dirpath doesn't contain any subdirectories: {temp_dir}"
+                        f"WARNING:torchtnt.utils.checkpoint:Input dirpath doesn't contain any subdirectories: {temp_dir}"
                     ],
                 )
                 self.assertFalse(restored)
diff --git a/tests/framework/callbacks/test_checkpoint_utils.py b/tests/framework/callbacks/test_checkpoint_utils.py
index ac1a019a98..f917fcd942 100644
--- a/tests/framework/callbacks/test_checkpoint_utils.py
+++ b/tests/framework/callbacks/test_checkpoint_utils.py
@@ -6,411 +6,16 @@
 
 # pyre-strict
 
-import os
-import shutil
-import tempfile
 import unittest
 
-import torch
-import torch.distributed as dist
-from torch import nn
-from torchsnapshot import Snapshot
-from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
 from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state
 
 from torchtnt.framework.callbacks._checkpoint_utils import (
-    _delete_checkpoint,
-    _metadata_exists,
     _prepare_app_state_for_checkpoint,
-    _retrieve_checkpoint_dirpaths,
-    _sort_by_metric_value,
-    _sort_by_recency,
-    get_best_checkpoint_path,
-    get_checkpoint_dirpaths,
-    get_latest_checkpoint_path,
-    rank_zero_read_and_broadcast,
 )
-from torchtnt.utils.distributed import get_global_rank, PGWrapper, spawn_multi_process
-from torchtnt.utils.env import init_from_env
-from torchtnt.utils.fsspec import get_filesystem
-from torchtnt.utils.test_utils import skip_if_not_distributed
-
-METADATA_FNAME: str = ".metadata"
 
 
 class CheckpointUtilsTest(unittest.TestCase):
-    @staticmethod
-    def _create_snapshot_metadata(output_dir: str) -> None:
-        path = os.path.join(output_dir, METADATA_FNAME)
-        with open(path, "w"):
-            pass
-
-    def test_latest_checkpoint_path(self) -> None:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            self.assertIsNone(get_latest_checkpoint_path(temp_dir))
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            latest_path = os.path.join(temp_dir, "epoch_0_step_0")
-            os.mkdir(latest_path)
-            self.assertEqual(
-                get_latest_checkpoint_path(temp_dir),
-                latest_path,
-            )
-            self.assertEqual(
-                get_latest_checkpoint_path(temp_dir, METADATA_FNAME),
-                None,
-            )
-            self._create_snapshot_metadata(latest_path)
-            self.assertEqual(
-                get_latest_checkpoint_path(temp_dir, METADATA_FNAME),
-                latest_path,
-            )
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            path_1 = os.path.join(temp_dir, "epoch_0_step_0")
-            os.mkdir(path_1)
-            self._create_snapshot_metadata(path_1)
-            path_2 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.002")
-            os.mkdir(path_2)
-            self._create_snapshot_metadata(path_2)
-
-            # Missing metadata file
-            path_3 = os.path.join(temp_dir, "epoch_1_step_100")
-            os.mkdir(path_3)
-
-            # Ill-formatted name
-            path_4 = os.path.join(temp_dir, "epoch_700")
-            os.mkdir(path_4)
-            self.assertEqual(
-                get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2
-            )
-
-    @skip_if_not_distributed
-    def test_latest_checkpoint_path_distributed(self) -> None:
-        spawn_multi_process(
-            2,
-            "gloo",
-            self._latest_checkpoint_path_distributed,
-        )
-
-    @staticmethod
-    def _latest_checkpoint_path_distributed() -> None:
-        tc = unittest.TestCase()
-        is_rank0 = get_global_rank() == 0
-
-        if is_rank0:
-            temp_dir = tempfile.mkdtemp()
-        else:
-            temp_dir = ""
-        tc.assertIsNone(get_latest_checkpoint_path(temp_dir))
-        if is_rank0:
-            shutil.rmtree(temp_dir)  # delete temp directory
-
-        if is_rank0:
-            temp_dir = tempfile.mkdtemp()
-            path_1 = os.path.join(temp_dir, "epoch_0_step_0")
-            os.mkdir(path_1)
-            CheckpointUtilsTest._create_snapshot_metadata(path_1)
-            path_2 = os.path.join(temp_dir, "epoch_0_step_100")
-            os.mkdir(path_2)
-            CheckpointUtilsTest._create_snapshot_metadata(path_2)
-
-            # Missing metadata file
-            path_3 = os.path.join(temp_dir, "epoch_1_step_100")
-            os.mkdir(path_3)
-
-            # Ill-formatted name
-            path_4 = os.path.join(temp_dir, "epoch_700")
-            os.mkdir(path_4)
-        else:
-            temp_dir = ""
-            path_2 = ""
-
-        pg = PGWrapper(dist.group.WORLD)
-        path_container = [path_2] if is_rank0 else [None]
-        pg.broadcast_object_list(path_container, 0)
-        expected_path = path_container[0]
-        tc.assertIsNotNone(expected_path)
-        tc.assertEqual(
-            get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path
-        )
-
-        if is_rank0:
-            shutil.rmtree(temp_dir)  # delete temp directory
-
-    def test_best_checkpoint_path(self) -> None:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min"))
-
-            # no checkpoint w/ metric value
-            path = os.path.join(temp_dir, "epoch_0_step_0")
-            os.mkdir(path)
-            self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min"))
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            best_path = os.path.join(temp_dir, "epoch_0_step_0_val_loss=0.01")
-            os.mkdir(best_path)
-            self.assertEqual(
-                get_best_checkpoint_path(temp_dir, "val_loss", "min"),
-                best_path,
-            )
-            self.assertIsNone(
-                get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME),
-                None,
-            )
-            self._create_snapshot_metadata(best_path)
-            self.assertEqual(
-                get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME),
-                best_path,
-            )
-
-            # handle negative values
-            best_path_2 = os.path.join(temp_dir, "epoch_0_step_0_val_loss=-0.01")
-            os.mkdir(best_path_2)
-            self.assertEqual(
-                get_best_checkpoint_path(temp_dir, "val_loss", "min"),
-                best_path_2,
-            )
-
-            # handle "max" mode correctly
-            best_path_3 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.1")
-            os.mkdir(best_path_3)
-            self.assertEqual(
-                get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"),
-                best_path_3,
-            )
-
-            # handle different metric correctly
-            best_path_4 = os.path.join(temp_dir, "epoch_0_step_100_train_loss=0.2")
-            os.mkdir(best_path_4)
-            self.assertEqual(
-                get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"),
-                best_path_3,
-            )
-            self.assertEqual(
-                get_best_checkpoint_path(
-                    temp_dir, metric_name="train_loss", mode="max"
-                ),
-                best_path_4,
-            )
-
-    def test_retrieve_checkpoint_dirpaths(self) -> None:
-        """
-        Tests retrieving checkpoint directories from a given root directory
-        """
-        with tempfile.TemporaryDirectory() as temp_dir:
-            paths = [
-                "epoch_0_step_10",
-                "epoch_1_step_10",
-                "epoch_2_step_10",
-                "epoch_0_step_5",
-                "epoch_0_step_6",
-                "epoch_0_step_3",
-            ]
-            for path in paths[:-1]:
-                os.mkdir(os.path.join(temp_dir, path))
-            # make last path a file instead of a directory
-            with open(os.path.join(temp_dir, paths[-1]), "w"):
-                pass
-
-            # compares set equality since order of returned dirpaths is not guaranteed
-            # in _retrieve_checkpoint_dirpaths
-            self.assertEqual(
-                set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)),
-                {os.path.join(temp_dir, path) for path in paths[:-1]},
-            )
-            self.assertEqual(
-                _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"),
-                [],
-            )
-
-            # check metadata file is correct filtered for
-            # by creating metadata for 3rd path in list
-            with open(os.path.join(temp_dir, paths[2], ".metadata"), "w"):
-                pass
-
-            self.assertEqual(
-                set(
-                    _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata")
-                ),
-                {os.path.join(temp_dir, paths[2])},
-            )
-
-    def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
-        """
-        Tests retrieving checkpoint (w/ metrics) directories from a given root directory
-        """
-        with tempfile.TemporaryDirectory() as temp_dir:
-            paths = [
-                "epoch_0_step_10_val_loss=10",
-                "epoch_1_step_10_val_loss=5",
-                "epoch_2_step_10",
-                "epoch_0_step_5",
-                "epoch_0_step_6_train_loss=13",
-            ]
-            for path in paths:
-                os.mkdir(os.path.join(temp_dir, path))
-            # make last path a file instead of a directory
-            with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3"), "w"):
-                pass
-
-            # compares set equality since order of returned dirpaths is not guaranteed
-            # in _retrieve_checkpoint_dirpaths
-            self.assertEqual(
-                set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)),
-                {os.path.join(temp_dir, path) for path in paths},
-            )
-            self.assertEqual(
-                set(
-                    _retrieve_checkpoint_dirpaths(
-                        temp_dir, metadata_fname=None, metric_name="val_loss"
-                    )
-                ),
-                {
-                    os.path.join(temp_dir, path) for path in paths[:2]
-                },  # since last path is a file
-            )
-            self.assertEqual(
-                _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"),
-                [],
-            )
-
-            # check metadata file is correct filtered for
-            # by creating metadata for 3rd path in list
-            with open(os.path.join(temp_dir, paths[1], ".metadata"), "w"):
-                pass
-
-            self.assertEqual(
-                set(
-                    _retrieve_checkpoint_dirpaths(
-                        temp_dir, metadata_fname=".metadata", metric_name="val_loss"
-                    )
-                ),
-                {os.path.join(temp_dir, paths[1])},
-            )
-
-    @skip_if_not_distributed
-    def test_distributed_get_checkpoint_dirpaths(self) -> None:
-        spawn_multi_process(2, "gloo", self._distributed_get_checkpoint_dirpaths)
-
-    @staticmethod
-    def _distributed_get_checkpoint_dirpaths() -> None:
-        """
-        Tests that existing checkpoint directories are read and
-        properly registered on all ranks
-        """
-
-        @rank_zero_read_and_broadcast
-        def create_tmp_dir() -> str:
-            return tempfile.mkdtemp()
-
-        init_from_env()
-
-        temp_dir = create_tmp_dir()
-        try:
-            path1 = os.path.join(temp_dir, "epoch_0_step_10")
-            path2 = os.path.join(temp_dir, "epoch_1_step_20")
-            if get_global_rank() == 0:
-                os.mkdir(path1)
-                os.mkdir(path2)
-            torch.distributed.barrier()
-
-            ckpt_dirpaths = get_checkpoint_dirpaths(temp_dir)
-            tc = unittest.TestCase()
-            tc.assertEqual(set(ckpt_dirpaths), {path1, path2})
-
-            tc.assertEqual(
-                get_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), []
-            )
-        finally:
-            if get_global_rank() == 0:
-                shutil.rmtree(temp_dir)  # delete temp directory
-
-    def test_get_checkpoint_dirpaths(self) -> None:
-        """
-        Tests that `get_checkpoint_dirpaths` returns
-        the sorted checkpoint directories correctly
-        """
-        with tempfile.TemporaryDirectory() as temp_dir:
-            path1 = os.path.join(temp_dir, "epoch_1_step_20")
-            path2 = os.path.join(temp_dir, "epoch_4_step_130")
-            path3 = os.path.join(temp_dir, "epoch_0_step_10")
-            os.mkdir(path1)
-            os.mkdir(path2)
-            os.mkdir(path3)
-
-            self.assertEqual(
-                set(get_checkpoint_dirpaths(temp_dir)),
-                {path1, path2, path3},
-            )
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            path1 = os.path.join(temp_dir, "epoch_1_step_20_val_loss=0.01")
-            path2 = os.path.join(temp_dir, "epoch_4_step_130_val_loss=-0.2")
-            path3 = os.path.join(temp_dir, "epoch_0_step_10_val_loss=0.12")
-            os.mkdir(path1)
-            os.mkdir(path2)
-            os.mkdir(path3)
-
-            self.assertEqual(
-                set(get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")),
-                {path1, path2, path3},
-            )
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            self.assertEqual(
-                get_checkpoint_dirpaths(temp_dir),
-                [],
-            )
-
-    def test_checkpoint_sorting_utils(self) -> None:
-        """
-        Tests the sort utilities
-        """
-        paths = ["epoch_1_step_20", "epoch_4_step_130", "epoch_0_step_10_val_loss=10"]
-        self.assertEqual(_sort_by_recency(paths), [paths[2], paths[0], paths[1]])
-
-        paths = [
-            "epoch_1_step_20_val_loss=0.09",
-            "epoch_4_step_130_val_loss=29",
-            "epoch_0_step_10_val_loss=10",
-        ]
-        self.assertEqual(
-            _sort_by_metric_value(paths, mode="min"), [paths[1], paths[2], paths[0]]
-        )
-        self.assertEqual(
-            _sort_by_metric_value(paths, mode="max"), [paths[0], paths[2], paths[1]]
-        )
-
-    def test_delete_checkpoint(self) -> None:
-        """
-        Tests removing checkpoint directories
-        """
-        app_state = {"module": nn.Linear(2, 2)}
-        with tempfile.TemporaryDirectory() as temp_dir:
-            dirpath = os.path.join(temp_dir, "checkpoint")
-            Snapshot.take(dirpath, app_state=app_state)
-            self.assertTrue(os.path.exists(dirpath))
-            # check that error is thrown if .snapshot_metadata is not found in the directory when deleting
-            os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
-            with self.assertRaisesRegex(
-                RuntimeError, f"{temp_dir} does not contain .snapshot_metadata"
-            ):
-                _delete_checkpoint(temp_dir, SNAPSHOT_METADATA_FNAME)
-            _delete_checkpoint(dirpath)
-            self.assertFalse(os.path.exists(dirpath))
-
-    def test_metadata_exists(self) -> None:
-        app_state = {"module": nn.Linear(2, 2)}
-        with tempfile.TemporaryDirectory() as temp_dir:
-            dirpath = os.path.join(temp_dir, "checkpoint")
-            Snapshot.take(dirpath, app_state=app_state)
-
-            fs = get_filesystem(dirpath)
-            self.assertTrue(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME))
-
-            os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
-            self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME))
 
     def test_get_app_state(self) -> None:
         my_unit = DummyTrainUnit(input_dim=2)
diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py
index 91778df115..2257e683c2 100644
--- a/tests/utils/test_checkpoint.py
+++ b/tests/utils/test_checkpoint.py
@@ -5,9 +5,40 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
+import os
+import shutil
+import tempfile
 import unittest
 
-from torchtnt.utils.checkpoint import CheckpointPath, MetricData
+import torch
+
+import torch.distributed as dist
+from torch import nn
+from torchsnapshot import Snapshot
+from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
+from torchtnt.utils import get_global_rank, init_from_env
+
+from torchtnt.utils.checkpoint import (
+    _delete_checkpoint,
+    _metadata_exists,
+    _retrieve_checkpoint_dirpaths,
+    _sort_by_metric_value,
+    _sort_by_recency,
+    CheckpointPath,
+    get_best_checkpoint_path,
+    get_checkpoint_dirpaths,
+    get_latest_checkpoint_path,
+    MetricData,
+)
+from torchtnt.utils.distributed import (
+    PGWrapper,
+    rank_zero_read_and_broadcast,
+    spawn_multi_process,
+)
+from torchtnt.utils.fsspec import get_filesystem
+from torchtnt.utils.test_utils import skip_if_not_distributed
+
+METADATA_FNAME: str = ".metadata"
 
 
 class CheckpointPathTest(unittest.TestCase):
@@ -141,3 +172,378 @@ def test_compare_by_optimality(self) -> None:
         self.assertFalse(smaller.more_optimal_than(larger, mode="max"))
         self.assertTrue(smaller.more_optimal_than(larger, mode="min"))
         self.assertFalse(larger.more_optimal_than(smaller, mode="min"))
+
+
+class CheckpointUtilsTest(unittest.TestCase):
+    @staticmethod
+    def _create_snapshot_metadata(output_dir: str) -> None:
+        path = os.path.join(output_dir, METADATA_FNAME)
+        with open(path, "w"):
+            pass
+
+    def test_latest_checkpoint_path(self) -> None:
+        with tempfile.TemporaryDirectory() as temp_dir:
+            self.assertIsNone(get_latest_checkpoint_path(temp_dir))
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            latest_path = os.path.join(temp_dir, "epoch_0_step_0")
+            os.mkdir(latest_path)
+            self.assertEqual(
+                get_latest_checkpoint_path(temp_dir),
+                latest_path,
+            )
+            self.assertEqual(
+                get_latest_checkpoint_path(temp_dir, METADATA_FNAME),
+                None,
+            )
+            self._create_snapshot_metadata(latest_path)
+            self.assertEqual(
+                get_latest_checkpoint_path(temp_dir, METADATA_FNAME),
+                latest_path,
+            )
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            path_1 = os.path.join(temp_dir, "epoch_0_step_0")
+            os.mkdir(path_1)
+            self._create_snapshot_metadata(path_1)
+            path_2 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.002")
+            os.mkdir(path_2)
+            self._create_snapshot_metadata(path_2)
+
+            # Missing metadata file
+            path_3 = os.path.join(temp_dir, "epoch_1_step_100")
+            os.mkdir(path_3)
+
+            # Ill-formatted name
+            path_4 = os.path.join(temp_dir, "epoch_700")
+            os.mkdir(path_4)
+            self.assertEqual(
+                get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2
+            )
+
+    @skip_if_not_distributed
+    def test_latest_checkpoint_path_distributed(self) -> None:
+        spawn_multi_process(
+            2,
+            "gloo",
+            self._latest_checkpoint_path_distributed,
+        )
+
+    @staticmethod
+    def _latest_checkpoint_path_distributed() -> None:
+        tc = unittest.TestCase()
+        is_rank0 = get_global_rank() == 0
+
+        if is_rank0:
+            temp_dir = tempfile.mkdtemp()
+        else:
+            temp_dir = ""
+        tc.assertIsNone(get_latest_checkpoint_path(temp_dir))
+        if is_rank0:
+            shutil.rmtree(temp_dir)  # delete temp directory
+
+        if is_rank0:
+            temp_dir = tempfile.mkdtemp()
+            path_1 = os.path.join(temp_dir, "epoch_0_step_0")
+            os.mkdir(path_1)
+            CheckpointUtilsTest._create_snapshot_metadata(path_1)
+            path_2 = os.path.join(temp_dir, "epoch_0_step_100")
+            os.mkdir(path_2)
+            CheckpointUtilsTest._create_snapshot_metadata(path_2)
+
+            # Missing metadata file
+            path_3 = os.path.join(temp_dir, "epoch_1_step_100")
+            os.mkdir(path_3)
+
+            # Ill-formatted name
+            path_4 = os.path.join(temp_dir, "epoch_700")
+            os.mkdir(path_4)
+        else:
+            temp_dir = ""
+            path_2 = ""
+
+        pg = PGWrapper(dist.group.WORLD)
+        path_container = [path_2] if is_rank0 else [None]
+        pg.broadcast_object_list(path_container, 0)
+        expected_path = path_container[0]
+        tc.assertIsNotNone(expected_path)
+        tc.assertEqual(
+            get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path
+        )
+
+        if is_rank0:
+            shutil.rmtree(temp_dir)  # delete temp directory
+
+    def test_best_checkpoint_path(self) -> None:
+        with tempfile.TemporaryDirectory() as temp_dir:
+            self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min"))
+
+            # no checkpoint w/ metric value
+            path = os.path.join(temp_dir, "epoch_0_step_0")
+            os.mkdir(path)
+            self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min"))
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            best_path = os.path.join(temp_dir, "epoch_0_step_0_val_loss=0.01")
+            os.mkdir(best_path)
+            self.assertEqual(
+                get_best_checkpoint_path(temp_dir, "val_loss", "min"),
+                best_path,
+            )
+            self.assertIsNone(
+                get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME),
+                None,
+            )
+            self._create_snapshot_metadata(best_path)
+            self.assertEqual(
+                get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME),
+                best_path,
+            )
+
+            # handle negative values
+            best_path_2 = os.path.join(temp_dir, "epoch_0_step_0_val_loss=-0.01")
+            os.mkdir(best_path_2)
+            self.assertEqual(
+                get_best_checkpoint_path(temp_dir, "val_loss", "min"),
+                best_path_2,
+            )
+
+            # handle "max" mode correctly
+            best_path_3 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.1")
+            os.mkdir(best_path_3)
+            self.assertEqual(
+                get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"),
+                best_path_3,
+            )
+
+            # handle different metric correctly
+            best_path_4 = os.path.join(temp_dir, "epoch_0_step_100_train_loss=0.2")
+            os.mkdir(best_path_4)
+            self.assertEqual(
+                get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"),
+                best_path_3,
+            )
+            self.assertEqual(
+                get_best_checkpoint_path(
+                    temp_dir, metric_name="train_loss", mode="max"
+                ),
+                best_path_4,
+            )
+
+    def test_retrieve_checkpoint_dirpaths(self) -> None:
+        """
+        Tests retrieving checkpoint directories from a given root directory
+        """
+        with tempfile.TemporaryDirectory() as temp_dir:
+            paths = [
+                "epoch_0_step_10",
+                "epoch_1_step_10",
+                "epoch_2_step_10",
+                "epoch_0_step_5",
+                "epoch_0_step_6",
+                "epoch_0_step_3",
+            ]
+            for path in paths[:-1]:
+                os.mkdir(os.path.join(temp_dir, path))
+            # make last path a file instead of a directory
+            with open(os.path.join(temp_dir, paths[-1]), "w"):
+                pass
+
+            # compares set equality since order of returned dirpaths is not guaranteed
+            # in _retrieve_checkpoint_dirpaths
+            self.assertEqual(
+                set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)),
+                {os.path.join(temp_dir, path) for path in paths[:-1]},
+            )
+            self.assertEqual(
+                _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"),
+                [],
+            )
+
+            # check that the metadata file is correctly filtered for
+            # by creating metadata for the 3rd path in the list
+            with open(os.path.join(temp_dir, paths[2], ".metadata"), "w"):
+                pass
+
+            self.assertEqual(
+                set(
+                    _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata")
+                ),
+                {os.path.join(temp_dir, paths[2])},
+            )
+
+    def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
+        """
+        Tests retrieving checkpoint (w/ metrics) directories from a given root directory
+        """
+        with tempfile.TemporaryDirectory() as temp_dir:
+            paths = [
+                "epoch_0_step_10_val_loss=10",
+                "epoch_1_step_10_val_loss=5",
+                "epoch_2_step_10",
+                "epoch_0_step_5",
+                "epoch_0_step_6_train_loss=13",
+            ]
+            for path in paths:
+                os.mkdir(os.path.join(temp_dir, path))
+            # add a checkpoint-named file (not a directory); it must be ignored
+            with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3"), "w"):
+                pass
+
+            # compares set equality since order of returned dirpaths is not guaranteed
+            # in _retrieve_checkpoint_dirpaths
+            self.assertEqual(
+                set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)),
+                {os.path.join(temp_dir, path) for path in paths},
+            )
+            self.assertEqual(
+                set(
+                    _retrieve_checkpoint_dirpaths(
+                        temp_dir, metadata_fname=None, metric_name="val_loss"
+                    )
+                ),
+                {
+                    os.path.join(temp_dir, path) for path in paths[:2]
+                },  # only the first two dirs carry val_loss; the file is skipped
+            )
+            self.assertEqual(
+                _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"),
+                [],
+            )
+
+            # check that the metadata file is correctly filtered for
+            # by creating metadata for the 2nd path in the list
+            with open(os.path.join(temp_dir, paths[1], ".metadata"), "w"):
+                pass
+
+            self.assertEqual(
+                set(
+                    _retrieve_checkpoint_dirpaths(
+                        temp_dir, metadata_fname=".metadata", metric_name="val_loss"
+                    )
+                ),
+                {os.path.join(temp_dir, paths[1])},
+            )
+
+    @skip_if_not_distributed
+    def test_distributed_get_checkpoint_dirpaths(self) -> None:
+        spawn_multi_process(2, "gloo", self._distributed_get_checkpoint_dirpaths)
+
+    @staticmethod
+    def _distributed_get_checkpoint_dirpaths() -> None:
+        """
+        Tests that existing checkpoint directories are read and
+        properly registered on all ranks
+        """
+
+        @rank_zero_read_and_broadcast
+        def create_tmp_dir() -> str:
+            return tempfile.mkdtemp()
+
+        init_from_env()
+
+        temp_dir = create_tmp_dir()
+        try:
+            path1 = os.path.join(temp_dir, "epoch_0_step_10")
+            path2 = os.path.join(temp_dir, "epoch_1_step_20")
+            if get_global_rank() == 0:
+                os.mkdir(path1)
+                os.mkdir(path2)
+            torch.distributed.barrier()
+
+            ckpt_dirpaths = get_checkpoint_dirpaths(temp_dir)
+            tc = unittest.TestCase()
+            tc.assertEqual(set(ckpt_dirpaths), {path1, path2})
+
+            tc.assertEqual(
+                get_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), []
+            )
+        finally:
+            if get_global_rank() == 0:
+                shutil.rmtree(temp_dir)  # delete temp directory
+
+    def test_get_checkpoint_dirpaths(self) -> None:
+        """
+        Tests that `get_checkpoint_dirpaths` returns
+        the expected checkpoint directories (compared as sets, ignoring order)
+        """
+        with tempfile.TemporaryDirectory() as temp_dir:
+            path1 = os.path.join(temp_dir, "epoch_1_step_20")
+            path2 = os.path.join(temp_dir, "epoch_4_step_130")
+            path3 = os.path.join(temp_dir, "epoch_0_step_10")
+            os.mkdir(path1)
+            os.mkdir(path2)
+            os.mkdir(path3)
+
+            self.assertEqual(
+                set(get_checkpoint_dirpaths(temp_dir)),
+                {path1, path2, path3},
+            )
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            path1 = os.path.join(temp_dir, "epoch_1_step_20_val_loss=0.01")
+            path2 = os.path.join(temp_dir, "epoch_4_step_130_val_loss=-0.2")
+            path3 = os.path.join(temp_dir, "epoch_0_step_10_val_loss=0.12")
+            os.mkdir(path1)
+            os.mkdir(path2)
+            os.mkdir(path3)
+
+            self.assertEqual(
+                set(get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")),
+                {path1, path2, path3},
+            )
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            self.assertEqual(
+                get_checkpoint_dirpaths(temp_dir),
+                [],
+            )
+
+    def test_checkpoint_sorting_utils(self) -> None:
+        """
+        Tests the sort utilities
+        """
+        paths = ["epoch_1_step_20", "epoch_4_step_130", "epoch_0_step_10_val_loss=10"]
+        self.assertEqual(_sort_by_recency(paths), [paths[2], paths[0], paths[1]])
+
+        paths = [
+            "epoch_1_step_20_val_loss=0.09",
+            "epoch_4_step_130_val_loss=29",
+            "epoch_0_step_10_val_loss=10",
+        ]
+        self.assertEqual(
+            _sort_by_metric_value(paths, mode="min"), [paths[1], paths[2], paths[0]]
+        )
+        self.assertEqual(
+            _sort_by_metric_value(paths, mode="max"), [paths[0], paths[2], paths[1]]
+        )
+
+    def test_delete_checkpoint(self) -> None:
+        """
+        Tests removing checkpoint directories
+        """
+        app_state = {"module": nn.Linear(2, 2)}
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dirpath = os.path.join(temp_dir, "checkpoint")
+            Snapshot.take(dirpath, app_state=app_state)
+            self.assertTrue(os.path.exists(dirpath))
+            # check that error is thrown if .snapshot_metadata is not found in the directory when deleting
+            os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
+            with self.assertRaisesRegex(
+                RuntimeError, f"{temp_dir} does not contain .snapshot_metadata"
+            ):
+                _delete_checkpoint(temp_dir, SNAPSHOT_METADATA_FNAME)
+            _delete_checkpoint(dirpath)
+            self.assertFalse(os.path.exists(dirpath))
+
+    def test_metadata_exists(self) -> None:
+        app_state = {"module": nn.Linear(2, 2)}
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dirpath = os.path.join(temp_dir, "checkpoint")
+            Snapshot.take(dirpath, app_state=app_state)
+
+            fs = get_filesystem(dirpath)
+            self.assertTrue(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME))
+
+            os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
+            self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME))
diff --git a/torchtnt/framework/callbacks/_checkpoint_utils.py b/torchtnt/framework/callbacks/_checkpoint_utils.py
index 087c15c15b..674eb18fe5 100644
--- a/torchtnt/framework/callbacks/_checkpoint_utils.py
+++ b/torchtnt/framework/callbacks/_checkpoint_utils.py
@@ -6,268 +6,16 @@
 
 # pyre-strict
 
-import logging
-import os
-import re
 
-from typing import Any, Dict, List, Literal, Optional, Pattern, Tuple, TypeVar
-
-import fsspec
+from typing import Any, Dict
 
 from pyre_extensions import none_throws
-from torch import distributed as dist
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
 from torchtnt.framework.state import State
 from torchtnt.framework.unit import AppStateMixin
-from torchtnt.utils.distributed import rank_zero_read_and_broadcast
 
-from torchtnt.utils.fsspec import get_filesystem
 from torchtnt.utils.stateful import Stateful
 
-logger: logging.Logger = logging.getLogger(__name__)
-
-T = TypeVar("T")
-
-
-@rank_zero_read_and_broadcast
-def get_latest_checkpoint_path(
-    dirpath: str,
-    metadata_fname: Optional[str] = None,
-    process_group: Optional[dist.ProcessGroup] = None,
-) -> Optional[str]:
-    """
-    Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory.
-
-    Args:
-        dirpath: parent directory where checkpoints are saved.
-        metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
-        process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
-
-    Raises:
-        AssertionError if the checkpoint subdirectories are not named in the format epoch_{epoch}_step_{step}.
-    """
-
-    return _latest_checkpoint_path(dirpath, metadata_fname)
-
-
-def _latest_checkpoint_path(
-    dirpath: str, metadata_fname: Optional[str]
-) -> Optional[str]:
-    candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname)
-
-    # Initialize variables to store the largest epoch and step numbers
-    largest_subdirectory = None
-    largest_epoch = -1
-    largest_step = -1
-
-    # Iterate through all files and directories in the specified directory
-    for candidate in candidate_dirpaths:
-        # Extract the epoch and step numbers from the directory name
-        dirname = os.path.basename(candidate)
-
-        # dirname will be of the format epoch_N_step_M
-        # where N is the epoch number and M is the step number as integers
-        split = dirname.split("_")
-        if len(split) < 4:
-            raise AssertionError(
-                f"Expected 4 or more elements for pattern of epoch_N_step_M, but received {split})"
-            )
-
-        epoch_num, step_num = int(split[1]), int(split[3])
-        # Check if the current epoch and step numbers are larger than the largest ones found so far
-        if epoch_num > largest_epoch:
-            largest_epoch = epoch_num
-            largest_step = step_num
-            largest_subdirectory = dirname
-        elif largest_epoch == epoch_num and step_num > largest_step:
-            largest_step = step_num
-            largest_subdirectory = dirname
-
-    if largest_subdirectory is None:
-        return None
-
-    # Rejoin with the parent directory path and return the largest subdirectory
-    return os.path.join(dirpath, none_throws(largest_subdirectory))
-
-
-@rank_zero_read_and_broadcast
-def get_best_checkpoint_path(
-    dirpath: str,
-    metric_name: str,
-    mode: Literal["min", "max"],
-    metadata_fname: Optional[str] = None,
-    process_group: Optional[dist.ProcessGroup] = None,
-) -> Optional[str]:
-    """
-    Given a parent directory where checkpoints are saved, return the best checkpoint subdirectory.
-
-    Args:
-        dirpath: parent directory where checkpoints are saved.
-        metric_name: Name of the metric to use to find the best checkpoint.
-        mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
-        metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
-        process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
-    """
-
-    dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
-    if len(dirpaths) == 0:
-        # no checkpoints found
-        return None
-
-    best_checkpoint_path = None
-    best_metric_value = float("inf") if mode == "min" else float("-inf")
-    for dirpath in dirpaths:
-        dirname = os.path.basename(dirpath)
-        metric_value = float(dirname.split("=")[-1])
-
-        if mode == "min":
-            if metric_value < best_metric_value:
-                best_metric_value = metric_value
-                best_checkpoint_path = dirpath
-        else:
-            if metric_value > best_metric_value:
-                best_metric_value = metric_value
-                best_checkpoint_path = dirpath
-
-    return best_checkpoint_path
-
-
-@rank_zero_read_and_broadcast
-def get_checkpoint_dirpaths(
-    dirpath: str,
-    metadata_fname: Optional[str] = None,
-    metric_name: Optional[str] = None,
-    process_group: Optional[dist.ProcessGroup] = None,
-) -> List[str]:
-    """
-    Given a parent directory where checkpoints are saved, returns the checkpoint subdirectories.
-    The order of the checkpoints is not guarenteed.
-
-    Args:
-        dirpath: parent directory where checkpoints are saved.
-        metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
-        metric_name: fetches all the checkpoint directories containing the metric name only.
-        process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
-    """
-
-    return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
-
-
-def _sort_by_recency(dirpaths: List[str]) -> List[str]:
-    """
-    Sorts the given list of directories by oldest to newest.
-
-    Args:
-        dirpaths: A list of directory paths.
-
-    Returns:
-        A sorted list of directory paths, sorted by recency.
-    """
-
-    def sort_fn(path: str) -> Tuple[int, int]:
-        x = os.path.basename(path)
-        return (int(x.split("_")[1]), int(x.split("_")[3]))
-
-    return sorted(dirpaths, key=sort_fn)
-
-
-def _sort_by_metric_value(
-    dirpaths: List[str], mode: Literal["min", "max"]
-) -> List[str]:
-    """
-    Sorts the given list of directories by the metric values.
-
-    Args:
-        dirpaths: A list of directory paths.
-        mode: Either 'min' or 'max'. If 'min', sorts in descending order. If 'max', sorts in ascending order
-
-    Returns:
-        A sorted list of directory paths, sorted by the metric values.
-    """
-
-    def sort_metric_fn(path: str) -> float:
-        x = os.path.basename(path)
-        metric_val = float(x.split("=")[-1])
-        return metric_val
-
-    return sorted(
-        dirpaths,
-        key=sort_metric_fn,
-        # sort descending if min, placing worst metric at top of list
-        reverse=(mode == "min"),
-    )
-
-
-def _retrieve_checkpoint_dirpaths(
-    dirpath: str,
-    metadata_fname: Optional[str],
-    metric_name: Optional[str] = None,
-) -> List[str]:
-    """
-    Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories
-
-    Args:
-        dirpath: parent directory where checkpoints are saved.
-        metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
-        metric_name: Name of the metric that must exist in checkpoint name.
-    """
-
-    if dirpath[-1] == "/":
-        # removes trailing forward slash if present
-        # required for regex search to work
-        dirpath = dirpath[:-1]
-
-    fs = get_filesystem(dirpath)
-
-    if not fs.exists(dirpath):
-        logger.warning(f"Input dirpath doesn't exist: {dirpath}")
-        return []
-
-    contents = fs.ls(dirpath, detail=True)
-    contents = [item["name"] for item in contents if item["type"] == "directory"]
-    if len(contents) == 0:
-        logger.warning(f"Input dirpath doesn't contain any subdirectories: {dirpath}")
-        return []
-
-    # Define the regex pattern to match the directory names
-    pattern = rf"^{dirpath}/epoch_\d+_step_\d+"
-    if metric_name:
-        # inject metric name in regex search
-        pattern += rf"_{metric_name}="
-    snapshot_dirpath_pattern: Pattern[str] = re.compile(pattern)
-    candidate_dirpaths = list(filter(snapshot_dirpath_pattern.match, contents))
-
-    if not metadata_fname:
-        # return early as we don't need to filter out any paths
-        return candidate_dirpaths
-
-    # Iterate through all files and directories in the specified directory
-    # and check if metedata is present or not
-    valid_ckpt_dirpaths = []
-    for candidate in candidate_dirpaths:
-        if not _metadata_exists(fs, candidate, metadata_fname):
-            logger.warning(
-                f"Snapshot metadata is missing from {candidate}! Skipping this path"
-            )
-            continue
-
-        valid_ckpt_dirpaths.append(candidate)
-
-    return valid_ckpt_dirpaths
-
-
-def _delete_checkpoint(dirpath: str, metadata_fname: Optional[str] = None) -> None:
-    fs = get_filesystem(dirpath)
-    if metadata_fname and not _metadata_exists(fs, dirpath, metadata_fname):
-        raise RuntimeError(f"{dirpath} does not contain {metadata_fname}")
-    fs.rm(dirpath, recursive=True)
-
-
-def _metadata_exists(
-    fs: fsspec.AbstractFileSystem, dirpath: str, metadata_fname: str
-) -> bool:
-    return fs.exists(os.path.join(dirpath, metadata_fname))
-
 
 # keys for use when checkpointing
 _TRAIN_PROGRESS_STATE_KEY = "train_progress"
diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py
index f5f56fdf25..ded257c412 100644
--- a/torchtnt/framework/callbacks/base_checkpointer.py
+++ b/torchtnt/framework/callbacks/base_checkpointer.py
@@ -16,7 +16,14 @@
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torchtnt.framework.callback import Callback
-from torchtnt.framework.callbacks._checkpoint_utils import (
+from torchtnt.framework.callbacks.checkpointer_types import (
+    BestCheckpointConfig,
+    RestoreOptions,
+)
+from torchtnt.framework.state import EntryPoint, State
+from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
+from torchtnt.framework.utils import get_timing_context
+from torchtnt.utils.checkpoint import (
     _delete_checkpoint,
     _metadata_exists,
     _sort_by_metric_value,
@@ -25,13 +32,6 @@
     get_checkpoint_dirpaths,
     get_latest_checkpoint_path,
 )
-from torchtnt.framework.callbacks.checkpointer_types import (
-    BestCheckpointConfig,
-    RestoreOptions,
-)
-from torchtnt.framework.state import EntryPoint, State
-from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
-from torchtnt.framework.utils import get_timing_context
 from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast
 from torchtnt.utils.fsspec import get_filesystem
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py
index c0ad4c3b8d..06cb6d33da 100644
--- a/torchtnt/utils/__init__.py
+++ b/torchtnt/utils/__init__.py
@@ -6,7 +6,13 @@
 
 # pyre-strict
 
-from .checkpoint import CheckpointPath, MetricData
+from .checkpoint import (
+    CheckpointPath,
+    get_best_checkpoint_path,
+    get_checkpoint_dirpaths,
+    get_latest_checkpoint_path,
+    MetricData,
+)
 from .device import (
     copy_data_to_device,
     CPUStats,
@@ -151,4 +157,7 @@
     "spawn_multi_process",
     "CheckpointPath",
     "MetricData",
+    "get_best_checkpoint_path",
+    "get_checkpoint_dirpaths",
+    "get_latest_checkpoint_path",
 ]
diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py
index 234464b6a7..d5ddc4b2f6 100644
--- a/torchtnt/utils/checkpoint.py
+++ b/torchtnt/utils/checkpoint.py
@@ -5,13 +5,20 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
+import logging
 import os
 import re
 from dataclasses import dataclass
 from functools import total_ordering
-from typing import Literal, Optional, Pattern
+from typing import List, Literal, Optional, Pattern, Tuple
 
+import fsspec
+import torch.distributed as dist
+from fsspec.core import url_to_fs
 from pyre_extensions import none_throws
+from torchtnt.utils.distributed import rank_zero_read_and_broadcast
+
+logger: logging.Logger = logging.getLogger(__name__)
 
 
 @dataclass
@@ -176,3 +183,243 @@ def __eq__(self, other: "CheckpointPath") -> bool:
 
     def __gt__(self, other: "CheckpointPath") -> bool:
         return self.newer_than(other)
+
+
+@rank_zero_read_and_broadcast
+def get_latest_checkpoint_path(
+    dirpath: str,
+    metadata_fname: Optional[str] = None,
+    process_group: Optional[dist.ProcessGroup] = None,
+) -> Optional[str]:
+    """
+    Given a parent directory where checkpoints are saved, return the path of the checkpoint subdirectory with the highest (epoch, step), or ``None`` if no checkpoint is found.
+
+    Args:
+        dirpath: parent directory where checkpoints are saved.
+        metadata_fname: if given, only checkpoints containing a file with this name are considered.
+        process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
+
+    Raises:
+        AssertionError if the checkpoint subdirectories are not named in the format epoch_{epoch}_step_{step}.
+    """
+
+    return _latest_checkpoint_path(dirpath, metadata_fname)  # process_group is not read here; NOTE(review): presumably consumed by the decorator
+
+
+def _latest_checkpoint_path(
+    dirpath: str, metadata_fname: Optional[str]
+) -> Optional[str]:
+    """
+    Pick the candidate checkpoint directory with the highest (epoch, step) pair.
+
+    Args:
+        dirpath: parent directory where checkpoints are saved.
+        metadata_fname: forwarded to ``_retrieve_checkpoint_dirpaths`` to select candidates.
+
+    Returns:
+        ``dirpath`` joined with the winning directory name, or ``None`` when no candidate was found.
+
+    Raises:
+        AssertionError if a candidate name splits on "_" into fewer than 4 parts.
+    """
+    newest_name: Optional[str] = None
+    newest_key: Tuple[int, int] = (-1, -1)
+
+    for ckpt_path in _retrieve_checkpoint_dirpaths(dirpath, metadata_fname):
+        # names look like epoch_N_step_M, so N and M sit at indices 1 and 3
+        name = os.path.basename(ckpt_path)
+        parts = name.split("_")
+        if len(parts) < 4:
+            raise AssertionError(
+                f"Expected 4 or more elements for pattern of epoch_N_step_M, but received {parts})"
+            )
+
+        # tuples compare epoch first and fall back to step on a tie
+        key = (int(parts[1]), int(parts[3]))
+        if key > newest_key:
+            newest_key = key
+            newest_name = name
+
+    if newest_name is None:
+        return None
+
+    return os.path.join(dirpath, newest_name)
+
+
+@rank_zero_read_and_broadcast
+def get_best_checkpoint_path(
+    dirpath: str,
+    metric_name: str,
+    mode: Literal["min", "max"],
+    metadata_fname: Optional[str] = None,
+    process_group: Optional[dist.ProcessGroup] = None,
+) -> Optional[str]:
+    """
+    Find the checkpoint subdirectory of ``dirpath`` whose name carries the best value of a metric.
+
+    Args:
+        dirpath: parent directory where checkpoints are saved.
+        metric_name: only checkpoints whose name contains this metric are considered.
+        mode: 'min' selects the lowest metric value; any other value selects the highest.
+        metadata_fname: when set, checkpoints lacking this metadata file are ignored.
+        process_group: not read in this body. NOTE(review): presumably consumed by the decorator - confirm.
+
+    Returns:
+        The path of the best checkpoint, or ``None`` if no candidate beats the initial +/-inf bound.
+    """
+    candidates = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
+    if not candidates:
+        # no checkpoints found
+        return None
+
+    minimize = mode == "min"
+    winner: Optional[str] = None
+    winner_value = float("inf") if minimize else float("-inf")
+
+    for candidate in candidates:
+        # the metric value is whatever follows the last "=" in the directory name
+        value = float(os.path.basename(candidate).split("=")[-1])
+        # strict comparison: on ties the earliest candidate is kept
+        improved = value < winner_value if minimize else value > winner_value
+        if improved:
+            winner_value = value
+            winner = candidate
+
+    return winner
+
+
+@rank_zero_read_and_broadcast
+def get_checkpoint_dirpaths(
+    dirpath: str,
+    metadata_fname: Optional[str] = None,
+    metric_name: Optional[str] = None,
+    process_group: Optional[dist.ProcessGroup] = None,
+) -> List[str]:
+    """
+    Given a parent directory where checkpoints are saved, returns the checkpoint subdirectories.
+    The order of the checkpoints is not guaranteed.
+
+    Args:
+        dirpath: parent directory where checkpoints are saved.
+        metadata_fname: if given, only checkpoints containing a file with this name are returned.
+        metric_name: if given, only checkpoint directories whose name contains this metric are returned.
+        process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
+    """
+
+    return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
+
+
+def _sort_by_recency(dirpaths: List[str]) -> List[str]:
+    """
+    Order checkpoint directories chronologically, oldest first.
+
+    Args:
+        dirpaths: A list of directory paths named like ``epoch_N_step_M``.
+
+    Returns:
+        A new list of the same paths in ascending (epoch, step) order.
+    """
+
+    def recency_key(path: str) -> Tuple[int, int]:
+        parts = os.path.basename(path).split("_")
+        return int(parts[1]), int(parts[3])
+
+    return sorted(dirpaths, key=recency_key)
+
+
+def _sort_by_metric_value(
+    dirpaths: List[str], mode: Literal["min", "max"]
+) -> List[str]:
+    """
+    Order checkpoint directories from worst to best metric value.
+
+    Args:
+        dirpaths: A list of directory paths whose names end in ``=<metric value>``.
+        mode: 'min' orders by descending value; any other value orders by ascending value.
+
+    Returns:
+        A new list of directory paths; the input list is left untouched.
+    """
+
+    def metric_of(path: str) -> float:
+        _, _, raw_value = os.path.basename(path).rpartition("=")
+        return float(raw_value)
+
+    ordered = list(dirpaths)
+    # when minimizing, larger is worse, so flip the order to keep the worst first
+    ordered.sort(
+        key=metric_of,
+        reverse=(mode == "min"),
+    )
+    return ordered
+
+
+def _retrieve_checkpoint_dirpaths(
+    dirpath: str,
+    metadata_fname: Optional[str],
+    metric_name: Optional[str] = None,
+) -> List[str]:
+    """
+    Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories
+
+    Args:
+        dirpath: parent directory where checkpoints are saved.
+        metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
+        metric_name: Name of the metric that must exist in checkpoint name.
+
+    Returns:
+        Subdirectory paths, as listed by the filesystem, named ``epoch_N_step_M`` (plus ``_{metric_name}=`` if given).
+    """
+    if dirpath.endswith("/"):
+        # removes trailing forward slash if present (also safe for an empty string)
+        # required for regex search to work
+        dirpath = dirpath[:-1]
+
+    fs, _ = url_to_fs(dirpath)
+
+    if not fs.exists(dirpath):
+        logger.warning(f"Input dirpath doesn't exist: {dirpath}")
+        return []
+
+    contents = fs.ls(dirpath, detail=True)
+    contents = [item["name"] for item in contents if item["type"] == "directory"]
+    if len(contents) == 0:
+        logger.warning(f"Input dirpath doesn't contain any subdirectories: {dirpath}")
+        return []
+
+    # Regex for the directory names; dirpath and metric_name are escaped so they match literally
+    pattern = rf"^{re.escape(dirpath)}/epoch_\d+_step_\d+"
+    if metric_name:
+        # inject metric name in regex search
+        pattern += rf"_{re.escape(metric_name)}="
+    snapshot_dirpath_pattern: Pattern[str] = re.compile(pattern)
+    candidate_dirpaths = list(filter(snapshot_dirpath_pattern.match, contents))
+
+    if not metadata_fname:
+        # return early as we don't need to filter out any paths
+        return candidate_dirpaths
+
+    # Keep only the candidates that actually contain the metadata file
+    valid_ckpt_dirpaths = []
+    for candidate in candidate_dirpaths:
+        if not _metadata_exists(fs, candidate, metadata_fname):
+            logger.warning(
+                f"Snapshot metadata is missing from {candidate}! Skipping this path"
+            )
+            continue
+        valid_ckpt_dirpaths.append(candidate)
+
+    return valid_ckpt_dirpaths
+
+
+def _delete_checkpoint(dirpath: str, metadata_fname: Optional[str] = None) -> None:
+    fs, _ = url_to_fs(dirpath)  # fsspec filesystem object for dirpath
+    if metadata_fname and not _metadata_exists(fs, dirpath, metadata_fname):  # only checked when a metadata name is given
+        raise RuntimeError(f"{dirpath} does not contain {metadata_fname}")  # refuse to delete a directory lacking that file
+    fs.rm(dirpath, recursive=True)  # removes dirpath recursively
+
+
+def _metadata_exists(
+    fs: fsspec.AbstractFileSystem, dirpath: str, metadata_fname: str
+) -> bool:
+    return fs.exists(os.path.join(dirpath, metadata_fname))  # whether <dirpath>/<metadata_fname> exists on fs