Describe the bug
When attempting to finetune (with no_load_optim=True) from the slim weights, the DeepSpeed engine attempts to load files with the naming structure zero_pp_rank_X_mp_rank_Y_optim_states.pt. These files don't exist in the slim weights.
Here's a traceback:
File "/workdir/megatron/training.py", line 187, in pretrain
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
File "/workdir/megatron/training.py", line 638, in setup_model_and_optimizer
    neox_args.iteration = load_checkpoint(
File "/workdir/megatron/checkpointing.py", line 247, in load_checkpoint
    checkpoint_name, state_dict = model.load_checkpoint(
File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 2783, in load_checkpoint
    success = self._load_zero_checkpoint(
File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 2962, in _load_zero_checkpoint
    zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 3056, in _get_all_zero_checkpoints
    return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)
File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 3028, in _get_all_zero_checkpoint_state_dicts
    _state = self.checkpoint_engine.load(
File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py", line 24, in load
    partition = torch.load(path, map_location=map_location)
File "/opt/conda/lib/python3.8/site-packages/torch/serialization.py", line 594, in load
    with _open_file_like(f, 'rb') as opened_file:
File "/opt/conda/lib/python3.8/site-packages/torch/serialization.py", line 230, in _open_file_like
    return _open_file(name_or_buffer, mode)
File "/opt/conda/lib/python3.8/site-packages/torch/serialization.py", line 211, in __init__
    super(_open_file, self).__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: '/shared_fs/20B_checkpoints/global_step150000/zero_pp_rank_0_mp_rank_03_optim_states.pt'
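For reference, the missing files follow the per-rank naming visible in the error; a quick illustrative sketch of that pattern (inferred from the path in the traceback, not DeepSpeed's actual implementation):

```python
def zero_ckpt_name(pp_rank: int, mp_rank: int) -> str:
    # One optimizer-state shard per (pipeline-parallel, model-parallel) rank
    # pair; the mp rank is zero-padded, matching the path in the traceback.
    return f"zero_pp_rank_{pp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt"

print(zero_ckpt_name(0, 3))  # zero_pp_rank_0_mp_rank_03_optim_states.pt
```

Slim weights ship only the model-weight files, so none of these per-rank optimizer shards are present on disk.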
To Reproduce
Run any training with zero_optimization.stage > 0, finetune=True, no_load_optim=True, and load set to a location containing the slim weights from here: https://the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/.
Expected behavior
Weights to be loaded successfully, and training to continue.
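For concreteness, a config fragment with the settings from the reproduction steps might look like the following (a sketch only; the stage value is arbitrary, and the load path is taken from the error message):

```yaml
# Sketch: any ZeRO stage > 0 together with these flags triggers the failure.
zero_optimization:
  stage: 1
finetune: true
no_load_optim: true
load: /shared_fs/20B_checkpoints   # directory containing the slim weights
```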
Proposed solution
Some relevant code locations:
It looks like using universal checkpoints simply skips the ZeRO optimizer loading, so it's perhaps not essential -- maybe a check for no_load_optim == True could also be used to skip it?
Screenshots
N/A
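A rough sketch of the skip proposed above (hypothetical wrapper; `_load_zero_checkpoint` is the DeepSpeed engine method from the traceback, and how `no_load_optim` would be plumbed through is assumed):

```python
def maybe_load_zero_checkpoint(engine, load_dir, tag, no_load_optim):
    # Hypothetical guard: when optimizer state was explicitly not requested,
    # skip ZeRO state loading instead of failing on the missing
    # zero_pp_rank_*_mp_rank_*_optim_states.pt files in slim weights.
    if no_load_optim:
        return True  # treat the weights-only checkpoint as successfully loaded
    return engine._load_zero_checkpoint(load_dir, tag)
```

This mirrors what the universal-checkpoint path appears to do (skip the ZeRO optimizer load entirely), just gated on no_load_optim instead.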
Environment (please complete the following information):
Config 1:
Config 2 (anonymized):
overwrite_values passed to NeoXArgs.from_ymls:
{"finetune": True, "no_load_optim": True}
Additional context
N/A