Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] zcc ema under non-pp when acc=1 #9941

Open
wants to merge 2 commits into
base: incubate/paddlenlp-fleety
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions paddlenlp/trainer/utils/flash_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,36 +211,34 @@ def ema_state_dict(self):
ema_state_dict[k] = tensor
ema_state_dict_master_weights = {}
for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
t = self.ema_buffer._slice(
meta["start"] - self.master_min_offset, meta["end"] - self.master_min_offset
).clone()
s = meta["start"] - self.master_min_offset
e = meta["end"] - self.master_min_offset
t = self.ema_buffer._slice(s, e).clone()
t.get_tensor()._set_dims(meta["shape"])
t.name = meta["name"]
ema_state_dict_master_weights[k] = t
ema_state_dict["master_weights"] = ema_state_dict_master_weights
return ema_state_dict

def load_ema_state_dict(self, state_dict):
    """Copy an already-loaded EMA checkpoint into the pinned CPU EMA buffers.

    Args:
        state_dict: dict mapping model-weight names to tensors, plus a
            "master_weights" sub-dict for the optimizer master weights.
            Callers may pass a filtered dict (e.g. under expert parallel
            the dp_rank>0 workers drop non-synced MoE params), so missing
            keys are skipped rather than treated as errors.
    """
    for k, tensor_meta in self.param_fusion_storage_helper.model_weights_metas.items():
        logger.info(f"[FC EMA] load model weight key={k}")
        start = tensor_meta["start"]
        end = tensor_meta["end"]
        if tensor_meta["buffer_index"] not in self.ema_buffer_model_params:
            continue  # non fp32 has no `self.ema_buffer_model_params`
        if k in state_dict:
            cpu_buffer = self.ema_buffer_model_params[tensor_meta["buffer_index"]]
            # Buffers are flat 1-D slabs; flatten before writing the slice.
            tensor = state_dict[k].flatten()
            cpu_buffer[start:end] = tensor

    ema_master = state_dict["master_weights"]
    for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items():
        logger.info(f"[FC EMA] load optimizer weight key={k}")
        # Offsets in the meta are global; rebase them to this rank's buffer.
        s = meta["start"] - self.master_min_offset
        e = meta["end"] - self.master_min_offset
        if k in ema_master:  # state-dict may be filtered (expert parallel)
            self.ema_buffer[s:e] = ema_master[k].flatten()


class ParamFusionStorageHelper:
Expand Down Expand Up @@ -407,10 +405,8 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
logger.info("Synced flash checkpoints.")

def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kwargs):
if not isinstance(model, PipelineLayer):
self.manager.flash_checkpoint_pipeline_hook(0)
# logger.info(
# f"check coef: {args.flash_save_ema_coef} {control.should_save}, {state.global_step}, {self.flash_ema_interval}"
# f"check coef: {args.flash_save_ema_coef} {control.should_save}, {state.global_step}, {self.flash_ema_interval}, type={type(model)}"
# )
if not control.should_save:
if args.flash_save_ema_coef is not None and state.global_step % self.flash_ema_interval == 0:
Expand All @@ -424,6 +420,8 @@ def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kw
non_cached_objects = (lr_scheduler.state_dict(), copy.deepcopy(state))
self.manager.get_idle_worker_for_saving((save_infos, non_cached_objects))
self.runtime_timer.stop()
if not isinstance(model, PipelineLayer):
self.manager.flash_checkpoint_pipeline_hook(0)

def _get_save_infos_based_on_steps(self, state, args, checkpoint_folder):
flash_checkpoint_dir = None
Expand Down Expand Up @@ -930,7 +928,15 @@ def run(self):
self.optimizer_fusion_storage_helper, self.param_fusion_storage_helper, self.ema_coef
)
if ema_ckpt_path is not None: # update ema if needed
self.flash_ema_processor.load_ema_state_dict(ema_ckpt_path)
logger.info(f"[FC EMA] load state dict from {ema_ckpt_path}")
with device_guard("cpu"):
state_dict = paddle.load(ema_ckpt_path)
if self.use_expert_parallel and self.dp_rank > 0:
state_dict = self._filter_moe_no_sync_optimizer_params(
self.model_meta_content, state_dict
)
self.flash_ema_processor.load_ema_state_dict(state_dict)
logger.info("[FC EMA] done loading")
ema_ckpt_path = None
elif task_type == FCTaskType.PREPARE:
start_time = time.time()
Expand Down
Loading