|
20 | 20 | )
|
21 | 21 | from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
|
22 | 22 | from torch.distributed.checkpoint.storage import StorageReader, StorageWriter
|
23 |
| - |
| 23 | +from torch.nn.parallel import DistributedDataParallel |
24 | 24 | from torchtnt.framework.callbacks._checkpoint_utils import (
|
25 | 25 | _prepare_app_state_for_checkpoint,
|
26 | 26 | _prepare_app_state_for_restore,
|
|
41 | 41 | from torchtnt.utils.checkpoint import BestCheckpointConfig, CheckpointPath
|
42 | 42 | from torchtnt.utils.optimizer import init_optim_state
|
43 | 43 | from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
|
| 44 | + |
44 | 45 | from torchtnt.utils.stateful import MultiStateful, Stateful
|
45 | 46 |
|
46 | 47 |
|
|
62 | 63 | FileSystemWriter as Writer,
|
63 | 64 | )
|
64 | 65 |
|
| 66 | +# below code provides BC for PyTorch versions which don't include distributed state dict |
| 67 | +# TODO: remove below code once this path is not longer supported |
| 68 | +try: |
| 69 | + import torch.distributed.checkpoint.state_dict as dsd |
| 70 | + |
| 71 | + # pyre-ignore Incompatible variable type [9] |
| 72 | + get_model_state_dict = dsd.get_model_state_dict |
| 73 | + |
| 74 | + def set_model_state_dict(mod: torch.nn.Module, state_dict: Dict[str, Any]) -> None: |
| 75 | + return dsd.set_model_state_dict(mod, state_dict) |
| 76 | + |
| 77 | +except ImportError: |
| 78 | + logger.warn( |
| 79 | + "torch.distributed.checkpoint.state_dict checkpoint is not available, " |
| 80 | + "falling back on defaults. Consider updating PyTorch, as this version " |
| 81 | + "will not be supported in the future." |
| 82 | + ) |
| 83 | + |
| 84 | + def get_model_state_dict(mod: torch.nn.Module) -> Dict[str, Any]: |
| 85 | + return mod.state_dict() |
| 86 | + |
| 87 | + def set_model_state_dict(mod: torch.nn.Module, state_dict: Dict[str, Any]) -> None: |
| 88 | + return mod.load_state_dict(state_dict) |
| 89 | + |
| 90 | + |
| 91 | +class DSDModelWrapper(Stateful): |
| 92 | + """This wrapper converts state dicts to Distributed State Dicts, essentially generating |
| 93 | + state dicts as if they were created using single-device methods. This is useful for |
| 94 | + when checkpoint models might be resharded, or loaded in notebooks or otherwise non-distributed |
| 95 | + settings. |
| 96 | +
|
| 97 | + """ |
| 98 | + |
| 99 | + def __init__(self, mod: torch.nn.Module) -> None: |
| 100 | + self.mod: torch.nn.Module = mod |
| 101 | + |
| 102 | + def state_dict(self) -> Dict[str, Any]: |
| 103 | + return get_model_state_dict(self.mod) |
| 104 | + |
| 105 | + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: |
| 106 | + set_model_state_dict(self.mod, state_dict) |
| 107 | + |
65 | 108 |
|
66 | 109 | class DistributedCheckpointSaver(BaseCheckpointer):
|
67 | 110 | """
|
@@ -148,6 +191,11 @@ def _checkpoint_impl(
|
148 | 191 | curr_snapshot_wait = hook == "on_train_end"
|
149 | 192 |
|
150 | 193 | app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)
|
| 194 | + |
| 195 | + for key, obj in app_state.items(): |
| 196 | + if isinstance(obj, DistributedDataParallel): |
| 197 | + app_state[key] = DSDModelWrapper(obj) |
| 198 | + |
151 | 199 | # TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
|
152 | 200 | if self._async_checkpoint:
|
153 | 201 | with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
|
@@ -315,14 +363,17 @@ def restore(
|
315 | 363 | )
|
316 | 364 |
|
317 | 365 | # necessary for loading optimizers since states are initialized lazy
|
318 |
| - for obj in app_state.values(): |
| 366 | + for key, obj in app_state.items(): |
319 | 367 | # sometimes optimizers are actually held in a wrapper which handles calling
|
320 | 368 | # state_dict and load_state_dict, sa is the case for
|
321 | 369 | # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
|
322 | 370 | optimizer = getattr(obj, "optimizer", obj)
|
323 | 371 | if isinstance(optimizer, torch.optim.Optimizer):
|
324 | 372 | init_optim_state(optimizer)
|
325 | 373 |
|
| 374 | + if isinstance(obj, DistributedDataParallel): |
| 375 | + app_state[key] = DSDModelWrapper(obj) |
| 376 | + |
326 | 377 | try:
|
327 | 378 | dcp.load(
|
328 | 379 | {"app_state": MultiStateful(app_state)},
|
|
0 commit comments