Commit

zcc-ema fix load-state-dict when dp-moe
Meiyim committed Feb 25, 2025
1 parent 88b6b6a commit 82372a2
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions paddlenlp/trainer/utils/flash_checkpoint.py
@@ -220,27 +220,25 @@ def ema_state_dict(self):
         ema_state_dict["master_weights"] = ema_state_dict_master_weights
         return ema_state_dict
 
-    def load_ema_state_dict(self, path):
-        with device_guard("cpu"):
-            logger.info(f"[FC EMA] load state dict from {path}")
-            state_dict = paddle.load(path)
-            for k, tensor_meta in self.param_fusion_storage_helper.model_weights_metas.items():
-                logger.info(f"[FC EMA] load model weight key={k}")
-                start = tensor_meta["start"]
-                end = tensor_meta["end"]
-                if tensor_meta["buffer_index"] not in self.ema_buffer_model_params:
-                    continue  # non fp32 has no `self.ema_buffer_model_params`
+    def load_ema_state_dict(self, state_dict):
+        for k, tensor_meta in self.param_fusion_storage_helper.model_weights_metas.items():
+            logger.info(f"[FC EMA] load model weight key={k}")
+            start = tensor_meta["start"]
+            end = tensor_meta["end"]
+            if tensor_meta["buffer_index"] not in self.ema_buffer_model_params:
+                continue  # non fp32 has no `self.ema_buffer_model_params`
             if k in state_dict:
                 cpu_buffer = self.ema_buffer_model_params[tensor_meta["buffer_index"]]
                 tensor = state_dict[k].flatten()
                 cpu_buffer[start:end] = tensor
 
-            ema_master = state_dict["master_weights"]
-            for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
-                logger.info(f"[FC EMA] load optimizer weight key={k}")
-                s = meta["start"] - self.master_min_offset
-                e = meta["end"] - self.master_min_offset
+        ema_master = state_dict["master_weights"]
+        for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
+            logger.info(f"[FC EMA] load optimizer weight key={k}")
+            s = meta["start"] - self.master_min_offset
+            e = meta["end"] - self.master_min_offset
             if k in ema_master:  # state-dict is filtered
                 self.ema_buffer[s:e] = ema_master[k]
-            logger.info("[FC EMA] done loading")
 
 
 class ParamFusionStorageHelper:
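After this change, `load_ema_state_dict` no longer touches the filesystem: it assumes the caller hands it an already-loaded, already-filtered state dict, and it only copies flattened tensor regions into the fused EMA buffers, skipping any key the filter removed. A minimal, self-contained sketch of that copy pattern, with illustrative stand-ins for the metas and buffers (these are not the actual `ParamFusionStorageHelper` layout):

```python
import paddle

# Illustrative stand-ins: each meta maps a parameter name to its slice
# of a fused, flattened CPU buffer (buffer_index selects the buffer).
model_weights_metas = {
    "linear.weight": {"start": 0, "end": 6, "buffer_index": 0},
    "linear.bias": {"start": 6, "end": 9, "buffer_index": 0},
}
ema_buffers = {0: paddle.zeros([9], dtype="float32")}

# A state dict that a dp-moe filter has already pruned: "linear.bias"
# is absent, so its slice of the buffer must be left untouched.
state_dict = {"linear.weight": paddle.ones([2, 3], dtype="float32")}

for k, meta in model_weights_metas.items():
    if k in state_dict:  # filtered-out keys are simply skipped
        buf = ema_buffers[meta["buffer_index"]]
        buf[meta["start"]:meta["end"]] = state_dict[k].flatten()
```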
@@ -930,7 +928,14 @@ def run(self):
                         self.optimizer_fusion_storage_helper, self.param_fusion_storage_helper, self.ema_coef
                     )
                 if ema_ckpt_path is not None:  # update ema if needed
-                    self.flash_ema_processor.load_ema_state_dict(ema_ckpt_path)
+                    logger.info(f"[FC EMA] load state dict from {ema_ckpt_path}")
+                    with device_guard("cpu"):
+                        state_dict = paddle.load(ema_ckpt_path)
+                        state_dict = self._filter_moe_no_sync_optimizer_params(
+                            self.model_meta_content, state_dict
+                        )
+                    self.flash_ema_processor.load_ema_state_dict(state_dict)
+                    logger.info("[FC EMA] done loading")
                     ema_ckpt_path = None
                 elif task_type == FCTaskType.PREPARE:
                     start_time = time.time()
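The fix itself is the new filtering step in the worker: under data-parallel MoE, expert optimizer states are not synchronized across DP ranks, so an EMA checkpoint's `master_weights` may contain keys a given rank does not own. A hypothetical sketch of what such a filter could look like; the real `_filter_moe_no_sync_optimizer_params` is defined elsewhere in paddlenlp and may differ, and `local_param_names` is an assumed input:

```python
def filter_moe_no_sync_master_weights(local_param_names, state_dict):
    """Illustrative only: keep master weights owned by this DP rank.

    Keys dropped here are later skipped by load_ema_state_dict's
    `if k in ema_master` guard instead of corrupting the EMA buffer.
    """
    master = state_dict.get("master_weights", {})
    state_dict["master_weights"] = {
        k: v for k, v in master.items() if k in local_param_names
    }
    return state_dict
```

Moving `paddle.load` out of `FlashEMAProcessor` and into the worker is what enables this: the worker owns `model_meta_content`, so it can prune the state dict before the processor ever sees it.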
