Commit 680db8e

Anmol202005 authored and maanug-nv committed

Separate save_checkpoint into per-type execution paths

1 parent: c5f9dd3

1 file changed: megatron/training/checkpointing.py (183 additions, 95 deletions)
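The commit turns save_checkpoint into a thin dispatcher: it picks one helper per checkpoint type and format, and keeps whatever async save request that helper returns. A self-contained sketch of the resulting control flow (stand-in stubs and return values for illustration only; the real helpers and signatures are in the diff below):

    from enum import Enum, auto

    class CheckpointType(Enum):
        GLOBAL = auto()
        LOCAL = auto()
        LEGACY = auto()

    # Stand-ins for the helpers added by this commit; each returns an async save
    # request when one is created, or None when the save completes synchronously.
    def _save_global_dist_checkpoint(**kw): return "async-request"
    def _save_global_dcp_checkpoint(**kw): return None
    def _save_local_checkpoint(**kw): return "async-request"
    def _save_legacy_checkpoint(**kw): return None

    def save_checkpoint(ckpt_type, ckpt_format, **kw):
        if ckpt_type is CheckpointType.GLOBAL and ckpt_format == "torch_dist":
            return _save_global_dist_checkpoint(**kw)
        elif ckpt_type is CheckpointType.GLOBAL and ckpt_format in ("torch_dcp", "fsdp_dtensor"):
            return _save_global_dcp_checkpoint(**kw)
        elif ckpt_type is CheckpointType.LOCAL:
            return _save_local_checkpoint(**kw)
        else:
            return _save_legacy_checkpoint(**kw)

    print(save_checkpoint(CheckpointType.GLOBAL, "torch_dist"))  # -> async-request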
@@ -613,105 +613,44 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
 
     state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
     if ckpt_type == CheckpointType.GLOBAL and ckpt_format == "torch_dist":
-        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-            # TODO Handle non-empty directories (e.g., after a crash during saving).
-            ensure_directory_exists(checkpoint_name, check_parent=False)
-        if checkpointing_context is not None and 'save_strategy' in checkpointing_context:
-            save_strategy = checkpointing_context['save_strategy']
-            # Already saved once before - don't need to rerun sharding validation
-            validate_sharding_integrity = not args.ckpt_assume_constant_structure
-        else:
-            validate_sharding_integrity = True
-            save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
-            if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
-                save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
-            if args.async_save:
-                save_strategy.thread_count = args.dist_ckpt_workers
-            else:
-                # We don't allow per-rank parallel save for sync save
-                logger.warning('Per-rank parallel save is not supported for sync save. '
-                               'Setting args.dist_ckpt_workers to 1')
-                save_strategy.thread_count = 1
-            if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
-                cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
-                if cached_global_metadata is not None:
-                    logger.debug("Plugging in the read metadata from the load strategy...")
-                    save_strategy.cached_global_metadata = cached_global_metadata
-                else:
-                    logger.debug("Failed to plug in the read metadata from the load strategy...")
-
-            if args.ckpt_fully_parallel_save:
-                if args.ckpt_fully_parallel_save_process_group == 'dp':
-                    process_group = mpu.get_data_parallel_group(with_context_parallel=True)
-                elif args.ckpt_fully_parallel_save_process_group == 'ep_dp':
-                    process_group = mpu.get_expert_data_parallel_group()
-                save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, process_group,
-                                                                 args.ckpt_assume_constant_structure)
-            # Store save strategy for future checkpoint saves
-            if checkpointing_context is not None:
-                checkpointing_context['save_strategy'] = save_strategy
-        end_ckpt = time()
-        logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
-        async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
-                                                     async_sharded_save=args.async_save,
-                                                     validate_access_integrity=validate_sharding_integrity,
-                                                     preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn,
-                                                     content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata))
-        # [ModelOpt]: save sharded modelopt_state
-        if has_nvidia_modelopt:
-            save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1))
+        async_save_request = _save_global_dist_checkpoint(
+            args=args,
+            model=model,
+            state_dict=state_dict,
+            sharded_sd_metadata=sharded_sd_metadata,
+            checkpoint_name=checkpoint_name,
+            checkpointing_context=checkpointing_context,
+            preprocess_common_state_dict_fn=preprocess_common_state_dict_fn,
+            rank=rank,
+            start_ckpt=start_ckpt,
+        )
     elif ckpt_type == CheckpointType.GLOBAL and ckpt_format in ["torch_dcp", "fsdp_dtensor"]:
-        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-            # TODO Handle non-empty directories (e.g., after a crash during saving).
-            ensure_directory_exists(checkpoint_name, check_parent=False)
-
-        if ckpt_format == "fsdp_dtensor":
-            state_dict = preprocess_fsdp_dtensor_state_dict(args, state_dict, model[0])
-
-        fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(checkpoint_name)
-        torch.distributed.checkpoint.save(
+        async_save_request = _save_global_dcp_checkpoint(
+            args=args,
+            model=model,
             state_dict=state_dict,
-            storage_writer=fs_storage_writer,
+            checkpoint_name=checkpoint_name,
+            ckpt_format=ckpt_format,
+        )
+    elif ckpt_type == CheckpointType.LOCAL:
+        async_save_request = _save_local_checkpoint(
+            args=args,
+            state_dict=state_dict,
+            sharded_sd_metadata=sharded_sd_metadata,
+            iteration=iteration,
+            checkpointing_context=checkpointing_context,
+            rank=rank,
+            start_ckpt=start_ckpt,
         )
     else:
-        # [ModelOpt]: Inject modelopt_state into state_dict
-        if has_nvidia_modelopt:
-            if ckpt_type == CheckpointType.LOCAL:
-                print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
-            else:
-                save_modelopt_state(model, state_dict)
-
-        end_ckpt = time()
-        logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
-        if ckpt_type == CheckpointType.LOCAL:
-            try:
-                from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
-            except ModuleNotFoundError:
-                raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
-                                   "checkpointing but was not found. Please ensure it is installed.")
-            if (sharded_sd_metadata or {}).get('distrib_optim_sharding_type') in ['fully_reshardable', 'dp_zero_gather_scatter']:
-                # Note: Currently full reshardabilty is not supported when local checkpoints are used.
-                raise RuntimeError(
-                    f"Local checkpointing does not support optimizer sharding type '{sharded_sd_metadata['distrib_optim_sharding_type']}'. "
-                    "Don't use '--dist-ckpt-optim-fully-reshardable' when saving local checkpoints."
-                )
-            algo = args.non_persistent_local_ckpt_algo
-            cached_metadata = None
-            if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
-                cached_metadata = checkpointing_context['local_checkpoint_cache']
-            state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
-                state_dict, algo=algo, cached_metadata=cached_metadata,
-                parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
-            )
-            async_save_request = checkpointing_context['local_checkpoint_manager'].save(
-                state_dict_for_save, iteration, is_async=bool(args.async_save)
-            )
-            checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
-        else:
-            assert ckpt_type == CheckpointType.LEGACY
-            # Save.
-            ensure_directory_exists(checkpoint_name)
-            torch.save(state_dict, checkpoint_name)
+        assert ckpt_type == CheckpointType.LEGACY
+        async_save_request = _save_legacy_checkpoint(
+            model=model,
+            state_dict=state_dict,
+            checkpoint_name=checkpoint_name,
+            rank=rank,
+            start_ckpt=start_ckpt,
+        )
     start_misc = time()
     if ckpt_type != CheckpointType.LOCAL:
         if not args.async_save:
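The torch_dcp/fsdp_dtensor branch above hands the save off to PyTorch's torch.distributed.checkpoint API through the new _save_global_dcp_checkpoint helper (defined in the next hunk). A minimal usage sketch of that API call, with a hypothetical output directory, assuming an initialized process group or a single-process save:

    import torch
    import torch.distributed.checkpoint as dcp

    # Same call shape as _save_global_dcp_checkpoint: a FileSystemWriter pointed
    # at the checkpoint directory, then a synchronous dcp.save of the state dict.
    state_dict = {"weight": torch.zeros(4)}
    writer = dcp.FileSystemWriter("/tmp/example_ckpt")  # hypothetical path
    dcp.save(state_dict=state_dict, storage_writer=writer)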
@@ -822,6 +761,155 @@ def wandb_finalize_fn():
 
     ft_integration.on_checkpointing_end(is_async_finalization=False)
 
+
+def _save_global_dist_checkpoint(
+    args,
+    model,
+    state_dict,
+    sharded_sd_metadata,
+    checkpoint_name,
+    checkpointing_context,
+    preprocess_common_state_dict_fn,
+    rank,
+    start_ckpt,
+):
+    """Save global distributed checkpoint in torch_dist format."""
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        # TODO Handle non-empty directories (e.g., after a crash during saving).
+        ensure_directory_exists(checkpoint_name, check_parent=False)
+    if checkpointing_context is not None and 'save_strategy' in checkpointing_context:
+        save_strategy = checkpointing_context['save_strategy']
+        # Already saved once before - don't need to rerun sharding validation
+        validate_sharding_integrity = not args.ckpt_assume_constant_structure
+    else:
+        validate_sharding_integrity = True
+        save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
+        if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
+            save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
+        if args.async_save:
+            save_strategy.thread_count = args.dist_ckpt_workers
+        else:
+            # We don't allow per-rank parallel save for sync save
+            logger.warning('Per-rank parallel save is not supported for sync save. '
+                           'Setting args.dist_ckpt_workers to 1')
+            save_strategy.thread_count = 1
+        if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
+            cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
+            if cached_global_metadata is not None:
+                logger.debug("Plugging in the read metadata from the load strategy...")
+                save_strategy.cached_global_metadata = cached_global_metadata
+            else:
+                logger.debug("Failed to plug in the read metadata from the load strategy...")
+
+        if args.ckpt_fully_parallel_save:
+            if args.ckpt_fully_parallel_save_process_group == 'dp':
+                process_group = mpu.get_data_parallel_group(with_context_parallel=True)
+            elif args.ckpt_fully_parallel_save_process_group == 'ep_dp':
+                process_group = mpu.get_expert_data_parallel_group()
+            save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, process_group,
+                                                             args.ckpt_assume_constant_structure)
+        # Store save strategy for future checkpoint saves
+        if checkpointing_context is not None:
+            checkpointing_context['save_strategy'] = save_strategy
+    end_ckpt = time()
+    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
+    async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
+                                                 async_sharded_save=args.async_save,
+                                                 validate_access_integrity=validate_sharding_integrity,
+                                                 preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn,
+                                                 content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata))
+    # [ModelOpt]: save sharded modelopt_state
+    if has_nvidia_modelopt:
+        save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1))
+    return async_save_request
+
+
+def _save_global_dcp_checkpoint(
+    args,
+    model,
+    state_dict,
+    checkpoint_name,
+    ckpt_format,
+):
+    """Save global checkpoint in torch_dcp or fsdp_dtensor format."""
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        # TODO Handle non-empty directories (e.g., after a crash during saving).
+        ensure_directory_exists(checkpoint_name, check_parent=False)
+
+    if ckpt_format == "fsdp_dtensor":
+        state_dict = preprocess_fsdp_dtensor_state_dict(args, state_dict, model[0])
+
+    fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(checkpoint_name)
+    torch.distributed.checkpoint.save(
+        state_dict=state_dict,
+        storage_writer=fs_storage_writer,
+    )
+    return None
+
+
+def _save_local_checkpoint(
+    args,
+    state_dict,
+    sharded_sd_metadata,
+    iteration,
+    checkpointing_context,
+    rank,
+    start_ckpt,
+):
+    """Save local (per-node) non-persistent checkpoint."""
+    # [ModelOpt]: Inject modelopt_state into state_dict
+    if has_nvidia_modelopt:
+        print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
+
+    end_ckpt = time()
+    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
+    try:
+        from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
+    except ModuleNotFoundError:
+        raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
+                           "checkpointing but was not found. Please ensure it is installed.")
+    if (sharded_sd_metadata or {}).get('distrib_optim_sharding_type') in ['fully_reshardable', 'dp_zero_gather_scatter']:
+        # Note: Currently full reshardabilty is not supported when local checkpoints are used.
+        raise RuntimeError(
+            f"Local checkpointing does not support optimizer sharding type '{sharded_sd_metadata['distrib_optim_sharding_type']}'. "
+            "Don't use '--dist-ckpt-optim-fully-reshardable' when saving local checkpoints."
+        )
+    algo = args.non_persistent_local_ckpt_algo
+    cached_metadata = None
+    if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
+        cached_metadata = checkpointing_context['local_checkpoint_cache']
+    state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
+        state_dict, algo=algo, cached_metadata=cached_metadata,
+        parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
+    )
+    async_save_request = checkpointing_context['local_checkpoint_manager'].save(
+        state_dict_for_save, iteration, is_async=bool(args.async_save)
+    )
+    checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
+    return async_save_request
+
+
+def _save_legacy_checkpoint(
+    model,
+    state_dict,
+    checkpoint_name,
+    rank,
+    start_ckpt
+):
+    """Save legacy checkpoint using plain torch.save."""
+    # [ModelOpt]: Inject modelopt_state into state_dict
+    if has_nvidia_modelopt:
+        save_modelopt_state(model, state_dict)
+
+    end_ckpt = time()
+    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
+
+    # Save.
+    ensure_directory_exists(checkpoint_name)
+    torch.save(state_dict, checkpoint_name)
+    return None
+
+
 @_disable_gc()
 def _async_delete_checkpoint_impl(save_path, iteration_to_delete, log_progress=False, lower_priority=False,
                                   cpu_priority=None, io_priority=None):
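One behavior preserved inside _save_global_dist_checkpoint above: the sharded save strategy is built and fully validated only on the first save, then cached in checkpointing_context['save_strategy']; later saves reuse it and skip sharding-integrity validation when --ckpt-assume-constant-structure is set. A self-contained sketch of that caching pattern, using hypothetical names rather than the Megatron-LM API:

    # Hypothetical illustration of the save-strategy caching shown above.
    def get_save_strategy(context, assume_constant_structure, build_strategy):
        if context is not None and 'save_strategy' in context:
            # Already saved once: reuse the cached strategy and re-validate
            # sharding only if the checkpoint structure may change between saves.
            return context['save_strategy'], not assume_constant_structure
        strategy = build_strategy()              # first save: build the strategy
        if context is not None:
            context['save_strategy'] = strategy  # cache it for future saves
        return strategy, True                    # first save always validates

    ctx = {}
    s1, validate1 = get_save_strategy(ctx, True, dict)  # builds and validates
    s2, validate2 = get_save_strategy(ctx, True, dict)  # reuses, skips validation
    assert s1 is s2 and validate1 and not validate2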
