From 94cc621deaacf56b0e4b0fce595806ee4d881073 Mon Sep 17 00:00:00 2001 From: Saurabh Mishra Date: Tue, 15 Apr 2025 17:15:34 -0700 Subject: [PATCH] Rank local checkpointing in DCP internal without collectives (#989) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/989 ### Context DCP metadata collectives become prohibitively expensive as the job scale grows. This PR introduces rank-local checkpointing (XLFormers style checkpointing) which basically saves and loads the checkpoint without any collective. The trade off for now is the dedupe and re-sharding. Support for these would be introduced soon. Differential Revision: D72390326 --- tests/framework/callbacks/test_dcp_saver.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py index fc12243229..8e168c770b 100644 --- a/tests/framework/callbacks/test_dcp_saver.py +++ b/tests/framework/callbacks/test_dcp_saver.py @@ -991,7 +991,9 @@ class DummyStorageWriter(FileSystemWriter): def __init__(self, path: str) -> None: super().__init__(path) - def set_up_storage_writer(self, is_coordinator: bool) -> None: + def set_up_storage_writer( + self, is_coordinator: bool, *args: Any, **kwargs: Any + ) -> None: pass @@ -999,5 +1001,7 @@ class DummyStorageReader(FileSystemReader): def __init__(self, path: str) -> None: super().__init__(path) - def set_up_storage_writer(self, is_coordinator: bool) -> None: + def set_up_storage_reader( + self, is_coordinator: bool, *args: Any, **kwargs: Any + ) -> None: pass