|
| 1 | +import os |
| 2 | +import time |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +from torch.distributed.fsdp import FSDPModule |
| 6 | +from torch.distributed.checkpoint.state_dict import ( |
| 7 | + _init_optim_state, |
| 8 | + get_model_state_dict, |
| 9 | + get_optimizer_state_dict, |
| 10 | + set_model_state_dict, |
| 11 | + set_optimizer_state_dict, |
| 12 | + StateDictOptions, |
| 13 | +) |
| 14 | +from torch.distributed.tensor import distribute_tensor, DTensor |
| 15 | + |
| 16 | + |
| 17 | +MODEL_CHECKPOINT = "model_state_dict.pt" |
| 18 | +OPTIM_CHECKPOINT = "optim_state_dict.pt" |
| 19 | + |
| 20 | +def get_latest_checkpoint_folder(path): |
| 21 | + max_num = None |
| 22 | + if not os.path.exists(path): |
| 23 | + return max_num |
| 24 | + for name in os.listdir(path): |
| 25 | + folder_path = os.path.join(path, name) |
| 26 | + if os.path.isdir(folder_path): |
| 27 | + try: |
| 28 | + num = int(name) |
| 29 | + if max_num is None or num > max_num: |
| 30 | + max_num = num |
| 31 | + except ValueError: |
| 32 | + pass # Skip non-numeric folder names |
| 33 | + return max_num |
| 34 | + |
| 35 | + |
| 36 | +class Checkpointer: |
| 37 | + def __init__(self, folder: str, dcp_api: bool): |
| 38 | + self.folder = folder |
| 39 | + self.dcp_api = dcp_api |
| 40 | + self.last_training_time = get_latest_checkpoint_folder(f"{folder}/{'dcp_api' if dcp_api else 'dtensor_api'}") |
| 41 | + |
| 42 | + def is_empty(self): |
| 43 | + return self.last_training_time is None |
| 44 | + |
| 45 | + def load_model(self, model: FSDPModule): |
| 46 | + last_model_checkpoint = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{self.last_training_time}/{MODEL_CHECKPOINT}" |
| 47 | + full_sd = torch.load(last_model_checkpoint, mmap=True, weights_only=True, map_location='cpu') |
| 48 | + if self.dcp_api: |
| 49 | + set_model_state_dict( |
| 50 | + model=model, |
| 51 | + model_state_dict=full_sd, |
| 52 | + options=StateDictOptions( |
| 53 | + full_state_dict=True, |
| 54 | + broadcast_from_rank0=True, |
| 55 | + ), |
| 56 | + ) |
| 57 | + return |
| 58 | + meta_sharded_sd = model.state_dict() |
| 59 | + sharded_sd = {} |
| 60 | + for param_name, full_tensor in full_sd.items(): |
| 61 | + sharded_meta_param = meta_sharded_sd.get(param_name) |
| 62 | + sharded_tensor = distribute_tensor( |
| 63 | + full_tensor, |
| 64 | + sharded_meta_param.device_mesh, |
| 65 | + sharded_meta_param.placements, |
| 66 | + ) |
| 67 | + sharded_sd[param_name] = nn.Parameter(sharded_tensor) |
| 68 | + # choose `assign=True` since we cannot call `copy_` on meta tensor |
| 69 | + model.load_state_dict(sharded_sd, strict=False, assign=True) |
| 70 | + |
| 71 | + def load_optim(self, model: FSDPModule, opt: torch.optim.Optimizer): |
| 72 | + last_optim_checkpoint = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{self.last_training_time}/{OPTIM_CHECKPOINT}" |
| 73 | + full_sd = torch.load(last_optim_checkpoint, mmap=True, weights_only=True, map_location='cpu') |
| 74 | + if self.dcp_api: |
| 75 | + set_optimizer_state_dict( |
| 76 | + model=model, |
| 77 | + optimizers=opt, |
| 78 | + optim_state_dict=full_sd, |
| 79 | + options=StateDictOptions( |
| 80 | + full_state_dict=True, |
| 81 | + broadcast_from_rank0=True, |
| 82 | + ) |
| 83 | + ) |
| 84 | + return |
| 85 | + PARAMS = "params" |
| 86 | + _init_optim_state(opt) |
| 87 | + param_groups = opt.state_dict()["param_groups"] |
| 88 | + state = opt.state_dict()["state"] |
| 89 | + |
| 90 | + full_param_groups = full_sd["param_groups"] |
| 91 | + full_state = full_sd["state"] |
| 92 | + |
| 93 | + for param_group, full_param_group in zip(param_groups, full_param_groups): |
| 94 | + for key, value in full_param_group.items(): |
| 95 | + if key == PARAMS: |
| 96 | + continue |
| 97 | + param_group[key] = value |
| 98 | + for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]): |
| 99 | + if pid not in state: |
| 100 | + continue |
| 101 | + param_state = state[pid] |
| 102 | + full_param_state = full_state[full_pid] |
| 103 | + for attr, full_tensor in full_param_state.items(): |
| 104 | + sharded_tensor = param_state[attr] |
| 105 | + if isinstance(sharded_tensor, DTensor): |
| 106 | + # exp_avg is DTensor |
| 107 | + param_state[attr] = distribute_tensor( |
| 108 | + full_tensor, |
| 109 | + sharded_tensor.device_mesh, |
| 110 | + sharded_tensor.placements, |
| 111 | + ) |
| 112 | + else: |
| 113 | + # step is plain tensor |
| 114 | + param_state[attr] = full_tensor |
| 115 | + opt.load_state_dict( |
| 116 | + { |
| 117 | + "param_groups": param_groups, |
| 118 | + "state": state, |
| 119 | + } |
| 120 | + ) |
| 121 | + |
| 122 | + def _get_full_model_state_dict(self, model: FSDPModule): |
| 123 | + if self.dcp_api: |
| 124 | + return get_model_state_dict( |
| 125 | + model=model, |
| 126 | + options=StateDictOptions( |
| 127 | + full_state_dict=True, |
| 128 | + cpu_offload=True, |
| 129 | + ) |
| 130 | + ) |
| 131 | + |
| 132 | + sharded_sd = model.state_dict() |
| 133 | + cpu_state_dict = {} |
| 134 | + for param_name, sharded_param in sharded_sd.items(): |
| 135 | + full_param = sharded_param.full_tensor() |
| 136 | + if torch.distributed.get_rank() == 0: |
| 137 | + cpu_state_dict[param_name] = full_param.cpu() |
| 138 | + else: |
| 139 | + del full_param |
| 140 | + return cpu_state_dict |
| 141 | + |
| 142 | + def _get_full_optimizer_state_dict( |
| 143 | + self, |
| 144 | + model: FSDPModule, |
| 145 | + opt: torch.optim.Optimizer, |
| 146 | + ): |
| 147 | + if self.dcp_api: |
| 148 | + return get_optimizer_state_dict( |
| 149 | + model=model, |
| 150 | + optimizers=opt, |
| 151 | + options=StateDictOptions( |
| 152 | + full_state_dict=True, |
| 153 | + cpu_offload=True, |
| 154 | + ) |
| 155 | + ) |
| 156 | + is_rank_zero = (torch.distributed.get_rank() == 0) |
| 157 | + sharded_sd = opt.state_dict() |
| 158 | + sharded_state = sharded_sd["state"] |
| 159 | + full_state = {} |
| 160 | + for group_id, sharded_group in sharded_state.items(): |
| 161 | + group_state = {} |
| 162 | + for attr, sharded_tensor in sharded_group.items(): |
| 163 | + if isinstance(sharded_tensor, DTensor): |
| 164 | + # "exp_avg" in AdamW is `DTensor` |
| 165 | + full_tensor = sharded_tensor.full_tensor() |
| 166 | + else: |
| 167 | + # "step" in AdamW is plain tensor |
| 168 | + full_tensor = sharded_tensor |
| 169 | + if is_rank_zero: |
| 170 | + group_state[attr] = full_tensor.cpu() |
| 171 | + else: |
| 172 | + del full_tensor |
| 173 | + if is_rank_zero: |
| 174 | + full_state[group_id] = group_state |
| 175 | + else: |
| 176 | + del group_state |
| 177 | + if is_rank_zero: |
| 178 | + return { |
| 179 | + "param_groups": sharded_sd["param_groups"], |
| 180 | + "state": full_state, |
| 181 | + } |
| 182 | + else: |
| 183 | + return {} |
| 184 | + |
| 185 | + def save(self, model: FSDPModule, optim: torch.optim.Optimizer): |
| 186 | + model_state_dict = self._get_full_model_state_dict(model) |
| 187 | + optim_state_dict = self._get_full_optimizer_state_dict(model, optim) |
| 188 | + if torch.distributed.get_rank() == 0: |
| 189 | + new_training_time = int(time.time() * 1000) |
| 190 | + new_checkpoint_folder = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{new_training_time}" |
| 191 | + new_model_checkpoint = f"{new_checkpoint_folder}/{MODEL_CHECKPOINT}" |
| 192 | + new_optim_checkpoint = f"{new_checkpoint_folder}/{OPTIM_CHECKPOINT}" |
| 193 | + os.makedirs(new_checkpoint_folder, exist_ok=True) |
| 194 | + torch.save(model_state_dict, new_model_checkpoint) |
| 195 | + torch.save(optim_state_dict, new_optim_checkpoint) |
0 commit comments