
Commit 435b1cb

LucasLLC authored and facebook-github-bot committed
wraps DDP models with DSD (#857)
Summary:
Pull Request resolved: #857

Distributed State Dict (DSD) is PyTorch's currently recommended way to ensure that the state dicts of parallelized models stay compatible with saves/loads in single-process or re-sharding scenarios. This diff updates dcp_saver to use DSD for DDP models. A good idea would be to wrap all models in TNT with DSD, as this could replace some of the wrapper logic for FSDP and would guarantee future compatibility.

N5551629 also contains a workaround for DDP models saved before this diff, which manually removes the "module." prefix from the checkpoint keys.

Differential Revision: D59234083
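For illustration only, a rough sketch, not taken from N5551629 or from this diff, of what such a prefix-stripping workaround could look like; the helper name strip_ddp_prefix is hypothetical:

# Hypothetical helper (not from this commit): strip the "module." prefix that
# plain DDP state dicts carry, so an old checkpoint can be loaded into a model
# whose state dict now uses canonical (unprefixed) key names via DSD.
from typing import Any, Dict


def strip_ddp_prefix(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }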
1 parent 58b6ea7 commit 435b1cb

File tree

1 file changed: +53 −2 lines changed


torchtnt/framework/callbacks/dcp_saver.py

+53 −2

@@ -20,7 +20,7 @@
 )
 from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
 from torch.distributed.checkpoint.storage import StorageReader, StorageWriter
-
+from torch.nn.parallel import DistributedDataParallel
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
     _prepare_app_state_for_restore,
@@ -41,6 +41,7 @@
 from torchtnt.utils.checkpoint import BestCheckpointConfig, CheckpointPath
 from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
+
 from torchtnt.utils.stateful import MultiStateful, Stateful


@@ -62,6 +63,48 @@
         FileSystemWriter as Writer,
     )

+# below code provides BC for PyTorch versions which don't include distributed state dict
+# TODO: remove below code once this path is no longer supported
+try:
+    import torch.distributed.checkpoint.state_dict as dsd
+
+    # pyre-ignore Incompatible variable type [9]
+    get_model_state_dict = dsd.get_model_state_dict
+
+    def set_model_state_dict(mod: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+        return dsd.set_model_state_dict(mod, state_dict)
+
+except ImportError:
+    logger.warn(
+        "torch.distributed.checkpoint.state_dict checkpoint is not available, "
+        "falling back on defaults. Consider updating PyTorch, as this version "
+        "will not be supported in the future."
+    )
+
+    def get_model_state_dict(mod: torch.nn.Module) -> Dict[str, Any]:
+        return mod.state_dict()
+
+    def set_model_state_dict(mod: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+        return mod.load_state_dict(state_dict)
+
+
+class DSDModelWrapper(Stateful):
+    """This wrapper converts state dicts to Distributed State Dicts, essentially generating
+    state dicts as if they were created using single-device methods. This is useful
+    when checkpointed models might be resharded, or loaded in notebooks or other non-distributed
+    settings.
+
+    """
+
+    def __init__(self, mod: torch.nn.Module) -> None:
+        self.mod: torch.nn.Module = mod
+
+    def state_dict(self) -> Dict[str, Any]:
+        return get_model_state_dict(self.mod)
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        set_model_state_dict(self.mod, state_dict)
+

 class DistributedCheckpointSaver(BaseCheckpointer):
     """
@@ -148,6 +191,11 @@ def _checkpoint_impl(
         curr_snapshot_wait = hook == "on_train_end"

         app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)
+
+        for key, obj in app_state.items():
+            if isinstance(obj, DistributedDataParallel):
+                app_state[key] = DSDModelWrapper(obj)
+
         # TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
         if self._async_checkpoint:
             with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
@@ -315,14 +363,17 @@ def restore(
         )

         # necessary for loading optimizers since states are initialized lazily
-        for obj in app_state.values():
+        for key, obj in app_state.items():
             # sometimes optimizers are actually held in a wrapper which handles calling
             # state_dict and load_state_dict, as is the case for
             # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case.
             optimizer = getattr(obj, "optimizer", obj)
             if isinstance(optimizer, torch.optim.Optimizer):
                 init_optim_state(optimizer)

+            if isinstance(obj, DistributedDataParallel):
+                app_state[key] = DSDModelWrapper(obj)
+
         try:
             dcp.load(
                 {"app_state": MultiStateful(app_state)},

0 commit comments

Comments
 (0)