@@ -613,105 +613,44 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
 
     state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
     if ckpt_type == CheckpointType.GLOBAL and ckpt_format == "torch_dist":
-        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-            # TODO Handle non-empty directories (e.g., after a crash during saving).
-            ensure_directory_exists(checkpoint_name, check_parent=False)
-        if checkpointing_context is not None and 'save_strategy' in checkpointing_context:
-            save_strategy = checkpointing_context['save_strategy']
-            # Already saved once before - don't need to rerun sharding validation
-            validate_sharding_integrity = not args.ckpt_assume_constant_structure
-        else:
-            validate_sharding_integrity = True
-            save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
-            if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
-                save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
-            if args.async_save:
-                save_strategy.thread_count = args.dist_ckpt_workers
-            else:
-                # We don't allow per-rank parallel save for sync save
-                logger.warning('Per-rank parallel save is not supported for sync save. '
-                               'Setting args.dist_ckpt_workers to 1')
-                save_strategy.thread_count = 1
-            if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
-                cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
-                if cached_global_metadata is not None:
-                    logger.debug("Plugging in the read metadata from the load strategy...")
-                    save_strategy.cached_global_metadata = cached_global_metadata
-                else:
-                    logger.debug("Failed to plug in the read metadata from the load strategy...")
-
-            if args.ckpt_fully_parallel_save:
-                if args.ckpt_fully_parallel_save_process_group == 'dp':
-                    process_group = mpu.get_data_parallel_group(with_context_parallel=True)
-                elif args.ckpt_fully_parallel_save_process_group == 'ep_dp':
-                    process_group = mpu.get_expert_data_parallel_group()
-                save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, process_group,
-                                                                 args.ckpt_assume_constant_structure)
-        # Store save strategy for future checkpoint saves
-        if checkpointing_context is not None:
-            checkpointing_context['save_strategy'] = save_strategy
-        end_ckpt = time()
-        logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
-        async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
-                                                     async_sharded_save=args.async_save,
-                                                     validate_access_integrity=validate_sharding_integrity,
-                                                     preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn,
-                                                     content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata))
-        # [ModelOpt]: save sharded modelopt_state
-        if has_nvidia_modelopt:
-            save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1))
+        async_save_request = _save_global_dist_checkpoint(
+            args=args,
+            model=model,
+            state_dict=state_dict,
+            sharded_sd_metadata=sharded_sd_metadata,
+            checkpoint_name=checkpoint_name,
+            checkpointing_context=checkpointing_context,
+            preprocess_common_state_dict_fn=preprocess_common_state_dict_fn,
+            rank=rank,
+            start_ckpt=start_ckpt,
+        )
     elif ckpt_type == CheckpointType.GLOBAL and ckpt_format in ["torch_dcp", "fsdp_dtensor"]:
-        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-            # TODO Handle non-empty directories (e.g., after a crash during saving).
-            ensure_directory_exists(checkpoint_name, check_parent=False)
-
-        if ckpt_format == "fsdp_dtensor":
-            state_dict = preprocess_fsdp_dtensor_state_dict(args, state_dict, model[0])
-
-        fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(checkpoint_name)
-        torch.distributed.checkpoint.save(
+        async_save_request = _save_global_dcp_checkpoint(
+            args=args,
+            model=model,
             state_dict=state_dict,
-            storage_writer=fs_storage_writer,
+            checkpoint_name=checkpoint_name,
+            ckpt_format=ckpt_format,
+        )
+    elif ckpt_type == CheckpointType.LOCAL:
+        async_save_request = _save_local_checkpoint(
+            args=args,
+            state_dict=state_dict,
+            sharded_sd_metadata=sharded_sd_metadata,
+            iteration=iteration,
+            checkpointing_context=checkpointing_context,
+            rank=rank,
+            start_ckpt=start_ckpt,
         )
     else:
-        # [ModelOpt]: Inject modelopt_state into state_dict
-        if has_nvidia_modelopt:
-            if ckpt_type == CheckpointType.LOCAL:
-                print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
-            else:
-                save_modelopt_state(model, state_dict)
-
-        end_ckpt = time()
-        logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
-        if ckpt_type == CheckpointType.LOCAL:
-            try:
-                from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
-            except ModuleNotFoundError:
-                raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
-                                   "checkpointing but was not found. Please ensure it is installed.")
-            if (sharded_sd_metadata or {}).get('distrib_optim_sharding_type') in ['fully_reshardable', 'dp_zero_gather_scatter']:
-                # Note: Currently full reshardabilty is not supported when local checkpoints are used.
-                raise RuntimeError(
-                    f"Local checkpointing does not support optimizer sharding type '{sharded_sd_metadata['distrib_optim_sharding_type']}'. "
-                    "Don't use '--dist-ckpt-optim-fully-reshardable' when saving local checkpoints."
-                )
-            algo = args.non_persistent_local_ckpt_algo
-            cached_metadata = None
-            if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
-                cached_metadata = checkpointing_context['local_checkpoint_cache']
-            state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
-                state_dict, algo=algo, cached_metadata=cached_metadata,
-                parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
-            )
-            async_save_request = checkpointing_context['local_checkpoint_manager'].save(
-                state_dict_for_save, iteration, is_async=bool(args.async_save)
-            )
-            checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
-        else:
-            assert ckpt_type == CheckpointType.LEGACY
-            # Save.
-            ensure_directory_exists(checkpoint_name)
-            torch.save(state_dict, checkpoint_name)
+        assert ckpt_type == CheckpointType.LEGACY
+        async_save_request = _save_legacy_checkpoint(
+            model=model,
+            state_dict=state_dict,
+            checkpoint_name=checkpoint_name,
+            rank=rank,
+            start_ckpt=start_ckpt,
+        )
     start_misc = time()
     if ckpt_type != CheckpointType.LOCAL:
         if not args.async_save:
@@ -822,6 +761,155 @@ def wandb_finalize_fn():
 
     ft_integration.on_checkpointing_end(is_async_finalization=False)
 
+
+def _save_global_dist_checkpoint(
+    args,
+    model,
+    state_dict,
+    sharded_sd_metadata,
+    checkpoint_name,
+    checkpointing_context,
+    preprocess_common_state_dict_fn,
+    rank,
+    start_ckpt,
+):
+    """Save global distributed checkpoint in torch_dist format."""
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        # TODO Handle non-empty directories (e.g., after a crash during saving).
+        ensure_directory_exists(checkpoint_name, check_parent=False)
+    if checkpointing_context is not None and 'save_strategy' in checkpointing_context:
+        save_strategy = checkpointing_context['save_strategy']
+        # Already saved once before - don't need to rerun sharding validation
+        validate_sharding_integrity = not args.ckpt_assume_constant_structure
+    else:
+        validate_sharding_integrity = True
+        save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
+        if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
+            save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
+        if args.async_save:
+            save_strategy.thread_count = args.dist_ckpt_workers
+        else:
+            # We don't allow per-rank parallel save for sync save
+            logger.warning('Per-rank parallel save is not supported for sync save. '
+                           'Setting args.dist_ckpt_workers to 1')
+            save_strategy.thread_count = 1
+        if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
+            cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
+            if cached_global_metadata is not None:
+                logger.debug("Plugging in the read metadata from the load strategy...")
+                save_strategy.cached_global_metadata = cached_global_metadata
+            else:
+                logger.debug("Failed to plug in the read metadata from the load strategy...")
+
+        if args.ckpt_fully_parallel_save:
+            if args.ckpt_fully_parallel_save_process_group == 'dp':
+                process_group = mpu.get_data_parallel_group(with_context_parallel=True)
+            elif args.ckpt_fully_parallel_save_process_group == 'ep_dp':
+                process_group = mpu.get_expert_data_parallel_group()
+            save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, process_group,
+                                                             args.ckpt_assume_constant_structure)
+    # Store save strategy for future checkpoint saves
+    if checkpointing_context is not None:
+        checkpointing_context['save_strategy'] = save_strategy
+    end_ckpt = time()
+    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
+    async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
+                                                 async_sharded_save=args.async_save,
+                                                 validate_access_integrity=validate_sharding_integrity,
+                                                 preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn,
+                                                 content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata))
+    # [ModelOpt]: save sharded modelopt_state
+    if has_nvidia_modelopt:
+        save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1))
+    return async_save_request
+
+
+def _save_global_dcp_checkpoint(
+    args,
+    model,
+    state_dict,
+    checkpoint_name,
+    ckpt_format,
+):
+    """Save global checkpoint in torch_dcp or fsdp_dtensor format."""
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        # TODO Handle non-empty directories (e.g., after a crash during saving).
+        ensure_directory_exists(checkpoint_name, check_parent=False)
+
+    if ckpt_format == "fsdp_dtensor":
+        state_dict = preprocess_fsdp_dtensor_state_dict(args, state_dict, model[0])
+
+    fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(checkpoint_name)
+    torch.distributed.checkpoint.save(
+        state_dict=state_dict,
+        storage_writer=fs_storage_writer,
+    )
+    return None
+
+
+def _save_local_checkpoint(
+    args,
+    state_dict,
+    sharded_sd_metadata,
+    iteration,
+    checkpointing_context,
+    rank,
+    start_ckpt,
+):
+    """Save local (per-node) non-persistent checkpoint."""
+    # [ModelOpt]: Inject modelopt_state into state_dict
+    if has_nvidia_modelopt:
+        print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
+
+    end_ckpt = time()
+    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
+    try:
+        from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
+    except ModuleNotFoundError:
+        raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
+                           "checkpointing but was not found. Please ensure it is installed.")
+    if (sharded_sd_metadata or {}).get('distrib_optim_sharding_type') in ['fully_reshardable', 'dp_zero_gather_scatter']:
+        # Note: Currently full reshardabilty is not supported when local checkpoints are used.
+        raise RuntimeError(
+            f"Local checkpointing does not support optimizer sharding type '{sharded_sd_metadata['distrib_optim_sharding_type']}'. "
+            "Don't use '--dist-ckpt-optim-fully-reshardable' when saving local checkpoints."
+        )
+    algo = args.non_persistent_local_ckpt_algo
+    cached_metadata = None
+    if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
+        cached_metadata = checkpointing_context['local_checkpoint_cache']
+    state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
+        state_dict, algo=algo, cached_metadata=cached_metadata,
+        parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
+    )
+    async_save_request = checkpointing_context['local_checkpoint_manager'].save(
+        state_dict_for_save, iteration, is_async=bool(args.async_save)
+    )
+    checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
+    return async_save_request
+
+
+def _save_legacy_checkpoint(
+    model,
+    state_dict,
+    checkpoint_name,
+    rank,
+    start_ckpt
+):
+    """Save legacy checkpoint using plain torch.save."""
+    # [ModelOpt]: Inject modelopt_state into state_dict
+    if has_nvidia_modelopt:
+        save_modelopt_state(model, state_dict)
+
+    end_ckpt = time()
+    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
+
+    # Save.
+    ensure_directory_exists(checkpoint_name)
+    torch.save(state_dict, checkpoint_name)
+    return None
+
+
 @_disable_gc()
 def _async_delete_checkpoint_impl(save_path, iteration_to_delete, log_progress=False, lower_priority=False,
                                    cpu_priority=None, io_priority=None):
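
For orientation, the skeleton below sketches the dispatch shape that save_checkpoint takes after this refactor. It is illustrative only: the helper bodies are stand-in stubs and the argument lists are abbreviated, whereas the real helpers take the full keyword arguments shown in the diff above.

    from enum import Enum, auto


    class CheckpointType(Enum):
        GLOBAL = auto()
        LOCAL = auto()
        LEGACY = auto()


    # Stand-in stubs for the helpers added in this commit (signatures simplified).
    def _save_global_dist_checkpoint(**kwargs):
        return "async-save-request"          # torch_dist save may return an async handle

    def _save_global_dcp_checkpoint(**kwargs):
        return None                          # torch_dcp / fsdp_dtensor save is synchronous

    def _save_local_checkpoint(**kwargs):
        return "async-save-request"          # per-node non-persistent save

    def _save_legacy_checkpoint(**kwargs):
        return None                          # plain torch.save


    def dispatch_save(ckpt_type, ckpt_format, **kwargs):
        """Mirrors the branch structure of save_checkpoint after the refactor."""
        if ckpt_type is CheckpointType.GLOBAL and ckpt_format == "torch_dist":
            return _save_global_dist_checkpoint(**kwargs)
        elif ckpt_type is CheckpointType.GLOBAL and ckpt_format in ("torch_dcp", "fsdp_dtensor"):
            return _save_global_dcp_checkpoint(**kwargs)
        elif ckpt_type is CheckpointType.LOCAL:
            return _save_local_checkpoint(**kwargs)
        else:
            assert ckpt_type is CheckpointType.LEGACY
            return _save_legacy_checkpoint(**kwargs)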