diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py
index 9b144f03b1..bc782c477d 100644
--- a/tests/utils/test_checkpoint.py
+++ b/tests/utils/test_checkpoint.py
@@ -6,6 +6,7 @@
 
 # pyre-strict
 import os
+import pickle
 import shutil
 import tempfile
 import unittest
@@ -173,6 +174,21 @@ def test_compare_by_optimality(self) -> None:
         self.assertTrue(smaller.more_optimal_than(larger, mode="min"))
         self.assertFalse(larger.more_optimal_than(smaller, mode="min"))
 
+    def test_pickling(self) -> None:
+        for path in (
+            "foo/epoch_0_step_1",
+            "file://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98",
+        ):
+            ckpt = CheckpointPath.from_str(path)
+
+            pickled = pickle.dumps(ckpt)
+
+            # Don't test equality because of custom protocol
+            self.assertTrue(path in str(pickled))
+
+            unpickled = pickle.loads(pickled)
+            self.assertEqual(unpickled, ckpt)
+
 
 class CheckpointUtilsTest(unittest.TestCase):
     @staticmethod
diff --git a/tests/utils/test_checkpoint_gpu.py b/tests/utils/test_checkpoint_gpu.py
new file mode 100644
index 0000000000..818aee4165
--- /dev/null
+++ b/tests/utils/test_checkpoint_gpu.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import torch.distributed as dist
+from torchtnt.utils import init_from_env
+from torchtnt.utils.checkpoint import get_checkpoint_dirpaths
+from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
+from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
+
+
+class TestCheckpointUtilsGPU(unittest.TestCase):
+
+    @skip_if_not_distributed
+    @skip_if_not_gpu
+    def test_get_checkpoint_dirpaths_distributed(self) -> None:
+        spawn_multi_process(
+            2,
+            "nccl",
+            self._test_get_checkpoint_dirpaths,
+        )
+
+    @staticmethod
+    def _test_get_checkpoint_dirpaths() -> None:
+        """
+        Tests retrieving checkpoint directories from a given root directory
+        using NCCL on GPUs with custom state for pickling.
+        """
+        init_from_env()
+        paths = [
+            "epoch_0_step_10",
+            "epoch_1_step_10_val_loss=10.5",
+            "epoch_2_step_10",
+            "epoch_0_step_5",
+            "epoch_0_step_6_acc=0.03",
+            "epoch_0_step_3",
+        ]
+
+        if get_global_rank() == 0:
+            temp_dir = tempfile.mkdtemp()
+            for path in paths:
+                os.mkdir(os.path.join(temp_dir, path))
+        else:
+            temp_dir = None
+
+        tc = unittest.TestCase()
+        # Only rank 0 will know about temp_dir
+        if get_global_rank() != 0:
+            tc.assertIsNone(temp_dir)
+
+        ckpt_dirpaths = get_checkpoint_dirpaths(
+            temp_dir, process_group=dist.group.WORLD
+        )
+
+        # Broadcast temp_dir to verify successful execution
+        temp_dir = [temp_dir] if get_global_rank() == 0 else [None]
+        dist.broadcast_object_list(temp_dir, src=0, group=dist.group.WORLD)
+        temp_dir = temp_dir[0]
+        tc.assertIsNotNone(temp_dir)
+
+        tc.assertEqual(
+            {str(x) for x in ckpt_dirpaths},
+            {os.path.join(temp_dir, path) for path in paths},
+        )
+
+        if get_global_rank() == 0:
+            shutil.rmtree(temp_dir)
diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py
index a677da03ff..5438afc830 100644
--- a/torchtnt/utils/checkpoint.py
+++ b/torchtnt/utils/checkpoint.py
@@ -198,6 +198,25 @@ def __eq__(self, other: "CheckpointPath") -> bool:
     def __gt__(self, other: "CheckpointPath") -> bool:
         return self.newer_than(other)
 
+    def __getstate__(self) -> str:
+        # Lightweight pickling to avoid broadcast errors
+        return self.path
+
+    def __setstate__(self, state: str) -> None:
+        # Match regex directly to avoid creating a new instance with `from_str`
+        path_match = self.PATH_REGEX.match(state)
+        assert path_match, f"Malformed checkpoint found when unpickling: {state}"
+
+        dirpath, epoch, step, metric_name, metric_value = path_match.groups()
+        self.dirpath = dirpath.rstrip("/")
+        self.epoch = int(epoch)
+        self.step = int(step)
+        self.metric_data = (
+            MetricData(name=metric_name, value=float(metric_value))
+            if metric_name and metric_value
+            else None
+        )
+
 
 @rank_zero_read_and_broadcast
 def get_latest_checkpoint_path(