diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py
index bebf574b965..967bc98f062 100644
--- a/megatron/training/checkpointing.py
+++ b/megatron/training/checkpointing.py
@@ -613,105 +613,44 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
     state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
     if ckpt_type == CheckpointType.GLOBAL and ckpt_format == "torch_dist":
-        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-            # TODO Handle non-empty directories (e.g., after a crash during saving).
-            ensure_directory_exists(checkpoint_name, check_parent=False)
-        if checkpointing_context is not None and 'save_strategy' in checkpointing_context:
-            save_strategy = checkpointing_context['save_strategy']
-            # Already saved once before - don't need to rerun sharding validation
-            validate_sharding_integrity = not args.ckpt_assume_constant_structure
-        else:
-            validate_sharding_integrity = True
-            save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
-            if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
-                save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
-            if args.async_save:
-                save_strategy.thread_count = args.dist_ckpt_workers
-            else:
-                # We don't allow per-rank parallel save for sync save
-                logger.warning('Per-rank parallel save is not supported for sync save. '
-                               'Setting args.dist_ckpt_workers to 1')
-                save_strategy.thread_count = 1
-            if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
-                cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
-                if cached_global_metadata is not None:
-                    logger.debug("Plugging in the read metadata from the load strategy...")
-                    save_strategy.cached_global_metadata = cached_global_metadata
-                else:
-                    logger.debug("Failed to plug in the read metadata from the load strategy...")
-
-            if args.ckpt_fully_parallel_save:
-                if args.ckpt_fully_parallel_save_process_group == 'dp':
-                    process_group = mpu.get_data_parallel_group(with_context_parallel=True)
-                elif args.ckpt_fully_parallel_save_process_group == 'ep_dp':
-                    process_group = mpu.get_expert_data_parallel_group()
-                save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, process_group,
-                                                                 args.ckpt_assume_constant_structure)
-            # Store save strategy for future checkpoint saves
-            if checkpointing_context is not None:
-                checkpointing_context['save_strategy'] = save_strategy
-        end_ckpt = time()
-        logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
-        async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
-                                                     async_sharded_save=args.async_save,
-                                                     validate_access_integrity=validate_sharding_integrity,
-                                                     preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn,
-                                                     content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata))
-        # [ModelOpt]: save sharded modelopt_state
-        if has_nvidia_modelopt:
-            save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1))
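+        # torch_dist saving is factored out into _save_global_dist_checkpoint below;
+        # it returns an async save request under args.async_save, otherwise None.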
["torch_dcp", "fsdp_dtensor"]: - if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - # TODO Handle non-empty directories (e.g., after a crash during saving). - ensure_directory_exists(checkpoint_name, check_parent=False) - - if ckpt_format == "fsdp_dtensor": - state_dict = preprocess_fsdp_dtensor_state_dict(args, state_dict, model[0]) - - fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(checkpoint_name) - torch.distributed.checkpoint.save( + async_save_request = _save_global_dcp_checkpoint( + args=args, + model=model, state_dict=state_dict, - storage_writer=fs_storage_writer, + checkpoint_name=checkpoint_name, + ckpt_format=ckpt_format, + ) + elif ckpt_type == CheckpointType.LOCAL: + async_save_request = _save_local_checkpoint( + args=args, + state_dict=state_dict, + sharded_sd_metadata=sharded_sd_metadata, + iteration=iteration, + checkpointing_context=checkpointing_context, + rank=rank, + start_ckpt=start_ckpt, ) else: - # [ModelOpt]: Inject modelopt_state into state_dict - if has_nvidia_modelopt: - if ckpt_type == CheckpointType.LOCAL: - print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.') - else: - save_modelopt_state(model, state_dict) - - end_ckpt = time() - logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ") - if ckpt_type == CheckpointType.LOCAL: - try: - from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict - except ModuleNotFoundError: - raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local " - "checkpointing but was not found. Please ensure it is installed.") - if (sharded_sd_metadata or {}).get('distrib_optim_sharding_type') in ['fully_reshardable', 'dp_zero_gather_scatter']: - # Note: Currently full reshardabilty is not supported when local checkpoints are used. - raise RuntimeError( - f"Local checkpointing does not support optimizer sharding type '{sharded_sd_metadata['distrib_optim_sharding_type']}'. " - "Don't use '--dist-ckpt-optim-fully-reshardable' when saving local checkpoints." - ) - algo = args.non_persistent_local_ckpt_algo - cached_metadata = None - if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context: - cached_metadata = checkpointing_context['local_checkpoint_cache'] - state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict( - state_dict, algo=algo, cached_metadata=cached_metadata, - parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True) - ) - async_save_request = checkpointing_context['local_checkpoint_manager'].save( - state_dict_for_save, iteration, is_async=bool(args.async_save) - ) - checkpointing_context['local_checkpoint_cache'] = cacheable_metadata - else: - assert ckpt_type == CheckpointType.LEGACY - # Save. 
-        # [ModelOpt]: Inject modelopt_state into state_dict
-        if has_nvidia_modelopt:
-            if ckpt_type == CheckpointType.LOCAL:
-                print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
-            else:
-                save_modelopt_state(model, state_dict)
-
-        end_ckpt = time()
-        logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
-        if ckpt_type == CheckpointType.LOCAL:
-            try:
-                from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
-            except ModuleNotFoundError:
-                raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
-                                   "checkpointing but was not found. Please ensure it is installed.")
-            if (sharded_sd_metadata or {}).get('distrib_optim_sharding_type') in ['fully_reshardable', 'dp_zero_gather_scatter']:
-                # Note: Currently full reshardabilty is not supported when local checkpoints are used.
-                raise RuntimeError(
-                    f"Local checkpointing does not support optimizer sharding type '{sharded_sd_metadata['distrib_optim_sharding_type']}'. "
-                    "Don't use '--dist-ckpt-optim-fully-reshardable' when saving local checkpoints."
-                )
-            algo = args.non_persistent_local_ckpt_algo
-            cached_metadata = None
-            if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
-                cached_metadata = checkpointing_context['local_checkpoint_cache']
-            state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
-                state_dict, algo=algo, cached_metadata=cached_metadata,
-                parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
-            )
-            async_save_request = checkpointing_context['local_checkpoint_manager'].save(
-                state_dict_for_save, iteration, is_async=bool(args.async_save)
-            )
-            checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
-        else:
-            assert ckpt_type == CheckpointType.LEGACY
-            # Save.
-            ensure_directory_exists(checkpoint_name)
-            torch.save(state_dict, checkpoint_name)
+        assert ckpt_type == CheckpointType.LEGACY
+        async_save_request = _save_legacy_checkpoint(
+            model=model,
+            state_dict=state_dict,
+            checkpoint_name=checkpoint_name,
+            rank=rank,
+            start_ckpt=start_ckpt,
+        )
 
     start_misc = time()
     if ckpt_type != CheckpointType.LOCAL:
         if not args.async_save:
@@ -822,6 +761,155 @@ def wandb_finalize_fn():
 
     ft_integration.on_checkpointing_end(is_async_finalization=False)
 
+
+def _save_global_dist_checkpoint(
+    args,
+    model,
+    state_dict,
+    sharded_sd_metadata,
+    checkpoint_name,
+    checkpointing_context,
+    preprocess_common_state_dict_fn,
+    rank,
+    start_ckpt,
+):
+    """Save global distributed checkpoint in torch_dist format."""
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        # TODO Handle non-empty directories (e.g., after a crash during saving).
+        ensure_directory_exists(checkpoint_name, check_parent=False)
+    if checkpointing_context is not None and 'save_strategy' in checkpointing_context:
+        save_strategy = checkpointing_context['save_strategy']
+        # Already saved once before - don't need to rerun sharding validation
+        validate_sharding_integrity = not args.ckpt_assume_constant_structure
+    else:
+        validate_sharding_integrity = True
+        save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
+        if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
+            save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
+        if args.async_save:
+            save_strategy.thread_count = args.dist_ckpt_workers
+        else:
+            # We don't allow per-rank parallel save for sync save
+            logger.warning('Per-rank parallel save is not supported for sync save. '
+                           'Setting args.dist_ckpt_workers to 1')
+            save_strategy.thread_count = 1
+        if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
+            cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
+            if cached_global_metadata is not None:
+                logger.debug("Plugging in the read metadata from the load strategy...")
+                save_strategy.cached_global_metadata = cached_global_metadata
+            else:
+                logger.debug("Failed to plug in the read metadata from the load strategy...")
+
+        if args.ckpt_fully_parallel_save:
+            if args.ckpt_fully_parallel_save_process_group == 'dp':
+                process_group = mpu.get_data_parallel_group(with_context_parallel=True)
+            elif args.ckpt_fully_parallel_save_process_group == 'ep_dp':
+                process_group = mpu.get_expert_data_parallel_group()
+            save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, process_group,
+                                                             args.ckpt_assume_constant_structure)
+        # Store save strategy for future checkpoint saves
+        if checkpointing_context is not None:
+            checkpointing_context['save_strategy'] = save_strategy
+    end_ckpt = time()
+    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
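+    # With async_sharded_save=True this returns an async request that the caller
+    # schedules later; a synchronous save blocks here and returns None.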
+    async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
+                                                 async_sharded_save=args.async_save,
+                                                 validate_access_integrity=validate_sharding_integrity,
+                                                 preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn,
+                                                 content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata))
+    # [ModelOpt]: save sharded modelopt_state
+    if has_nvidia_modelopt:
+        save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1))
+    return async_save_request
+
+
+def _save_global_dcp_checkpoint(
+    args,
+    model,
+    state_dict,
+    checkpoint_name,
+    ckpt_format,
+):
+    """Save global checkpoint in torch_dcp or fsdp_dtensor format."""
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        # TODO Handle non-empty directories (e.g., after a crash during saving).
+        ensure_directory_exists(checkpoint_name, check_parent=False)
+
+    if ckpt_format == "fsdp_dtensor":
+        state_dict = preprocess_fsdp_dtensor_state_dict(args, state_dict, model[0])
+
+    fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(checkpoint_name)
+    torch.distributed.checkpoint.save(
+        state_dict=state_dict,
+        storage_writer=fs_storage_writer,
+    )
+    # torch.distributed.checkpoint.save is synchronous, so there is no async request.
+    return None
+
+
+def _save_local_checkpoint(
+    args,
+    state_dict,
+    sharded_sd_metadata,
+    iteration,
+    checkpointing_context,
+    rank,
+    start_ckpt,
+):
+    """Save local (per-node) non-persistent checkpoint."""
+    # [ModelOpt]: modelopt_state cannot be injected into local checkpoints
+    if has_nvidia_modelopt:
+        print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
+
+    end_ckpt = time()
+    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
+    try:
+        from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
+    except ModuleNotFoundError:
+        raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
+                           "checkpointing but was not found. Please ensure it is installed.")
+    if (sharded_sd_metadata or {}).get('distrib_optim_sharding_type') in ['fully_reshardable', 'dp_zero_gather_scatter']:
+        # Note: Currently full reshardability is not supported when local checkpoints are used.
+        raise RuntimeError(
+            f"Local checkpointing does not support optimizer sharding type '{sharded_sd_metadata['distrib_optim_sharding_type']}'. "
+            "Don't use '--dist-ckpt-optim-fully-reshardable' when saving local checkpoints."
+        )
+    algo = args.non_persistent_local_ckpt_algo
+    cached_metadata = None
+    if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
+        cached_metadata = checkpointing_context['local_checkpoint_cache']
+    state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
+        state_dict, algo=algo, cached_metadata=cached_metadata,
+        parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
+    )
+    async_save_request = checkpointing_context['local_checkpoint_manager'].save(
+        state_dict_for_save, iteration, is_async=bool(args.async_save)
+    )
+    checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
+    return async_save_request
+
+
+def _save_legacy_checkpoint(
+    model,
+    state_dict,
+    checkpoint_name,
+    rank,
+    start_ckpt,
+):
+    """Save legacy checkpoint using plain torch.save."""
+    # [ModelOpt]: Inject modelopt_state into state_dict
+    if has_nvidia_modelopt:
+        save_modelopt_state(model, state_dict)
+
+    end_ckpt = time()
+    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
+
+    # Save.
+    ensure_directory_exists(checkpoint_name)
+    torch.save(state_dict, checkpoint_name)
+    return None
+
+
 @_disable_gc()
 def _async_delete_checkpoint_impl(save_path, iteration_to_delete, log_progress=False,
                                   lower_priority=False, cpu_priority=None, io_priority=None):