Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 183 additions & 95 deletions megatron/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,105 +613,44 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati

state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
if ckpt_type == CheckpointType.GLOBAL and ckpt_format == "torch_dist":
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
# TODO Handle non-empty directories (e.g., after a crash during saving).
ensure_directory_exists(checkpoint_name, check_parent=False)
if checkpointing_context is not None and 'save_strategy' in checkpointing_context:
save_strategy = checkpointing_context['save_strategy']
# Already saved once before - don't need to rerun sharding validation
validate_sharding_integrity = not args.ckpt_assume_constant_structure
else:
validate_sharding_integrity = True
save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
if args.async_save:
save_strategy.thread_count = args.dist_ckpt_workers
else:
# We don't allow per-rank parallel save for sync save
logger.warning('Per-rank parallel save is not supported for sync save. '
'Setting args.dist_ckpt_workers to 1')
save_strategy.thread_count = 1
if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
if cached_global_metadata is not None:
logger.debug("Plugging in the read metadata from the load strategy...")
save_strategy.cached_global_metadata = cached_global_metadata
else:
logger.debug("Failed to plug in the read metadata from the load strategy...")

if args.ckpt_fully_parallel_save:
if args.ckpt_fully_parallel_save_process_group == 'dp':
process_group = mpu.get_data_parallel_group(with_context_parallel=True)
elif args.ckpt_fully_parallel_save_process_group == 'ep_dp':
process_group = mpu.get_expert_data_parallel_group()
save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, process_group,
args.ckpt_assume_constant_structure)
# Store save strategy for future checkpoint saves
if checkpointing_context is not None:
checkpointing_context['save_strategy'] = save_strategy
end_ckpt = time()
logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
async_sharded_save=args.async_save,
validate_access_integrity=validate_sharding_integrity,
preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn,
content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata))
# [ModelOpt]: save sharded modelopt_state
if has_nvidia_modelopt:
save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1))
async_save_request = _save_global_dist_checkpoint(
args=args,
model=model,
state_dict=state_dict,
sharded_sd_metadata=sharded_sd_metadata,
checkpoint_name=checkpoint_name,
checkpointing_context=checkpointing_context,
preprocess_common_state_dict_fn=preprocess_common_state_dict_fn,
rank=rank,
start_ckpt=start_ckpt,
)
elif ckpt_type == CheckpointType.GLOBAL and ckpt_format in ["torch_dcp", "fsdp_dtensor"]:
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
# TODO Handle non-empty directories (e.g., after a crash during saving).
ensure_directory_exists(checkpoint_name, check_parent=False)

if ckpt_format == "fsdp_dtensor":
state_dict = preprocess_fsdp_dtensor_state_dict(args, state_dict, model[0])

fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(checkpoint_name)
torch.distributed.checkpoint.save(
async_save_request = _save_global_dcp_checkpoint(
args=args,
model=model,
state_dict=state_dict,
storage_writer=fs_storage_writer,
checkpoint_name=checkpoint_name,
ckpt_format=ckpt_format,
)
elif ckpt_type == CheckpointType.LOCAL:
async_save_request = _save_local_checkpoint(
args=args,
state_dict=state_dict,
sharded_sd_metadata=sharded_sd_metadata,
iteration=iteration,
checkpointing_context=checkpointing_context,
rank=rank,
start_ckpt=start_ckpt,
)
else:
# [ModelOpt]: Inject modelopt_state into state_dict
if has_nvidia_modelopt:
if ckpt_type == CheckpointType.LOCAL:
print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
else:
save_modelopt_state(model, state_dict)

end_ckpt = time()
logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
if ckpt_type == CheckpointType.LOCAL:
try:
from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
except ModuleNotFoundError:
raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
"checkpointing but was not found. Please ensure it is installed.")
if (sharded_sd_metadata or {}).get('distrib_optim_sharding_type') in ['fully_reshardable', 'dp_zero_gather_scatter']:
# Note: Currently full reshardabilty is not supported when local checkpoints are used.
raise RuntimeError(
f"Local checkpointing does not support optimizer sharding type '{sharded_sd_metadata['distrib_optim_sharding_type']}'. "
"Don't use '--dist-ckpt-optim-fully-reshardable' when saving local checkpoints."
)
algo = args.non_persistent_local_ckpt_algo
cached_metadata = None
if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
cached_metadata = checkpointing_context['local_checkpoint_cache']
state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
state_dict, algo=algo, cached_metadata=cached_metadata,
parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
)
async_save_request = checkpointing_context['local_checkpoint_manager'].save(
state_dict_for_save, iteration, is_async=bool(args.async_save)
)
checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
else:
assert ckpt_type == CheckpointType.LEGACY
# Save.
ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)
assert ckpt_type == CheckpointType.LEGACY
async_save_request = _save_legacy_checkpoint(
model=model,
state_dict=state_dict,
checkpoint_name=checkpoint_name,
rank=rank,
start_ckpt=start_ckpt,
)
start_misc = time()
if ckpt_type != CheckpointType.LOCAL:
if not args.async_save:
Expand Down Expand Up @@ -822,6 +761,155 @@ def wandb_finalize_fn():

ft_integration.on_checkpointing_end(is_async_finalization=False)


def _save_global_dist_checkpoint(
    args,
    model,
    state_dict,
    sharded_sd_metadata,
    checkpoint_name,
    checkpointing_context,
    preprocess_common_state_dict_fn,
    rank,
    start_ckpt,
):
    """Save a global distributed checkpoint in the ``torch_dist`` format.

    Builds (or reuses) a sharded save strategy, optionally wraps it for
    fully-parallel saving, then hands the state dict to
    ``dist_checkpointing.save``.

    Args:
        args: parsed arguments; reads the ``ckpt_*``, ``async_save`` and
            ``dist_ckpt_workers`` options.
        model: list of model chunks, forwarded to ModelOpt state saving.
        state_dict: sharded state dict to persist.
        sharded_sd_metadata: content metadata serialized with the checkpoint.
        checkpoint_name: target checkpoint directory.
        checkpointing_context: optional dict caching the save/load strategies
            between invocations.
        preprocess_common_state_dict_fn: hook run before the consistency check.
        rank: this process's rank (logging only).
        start_ckpt: timestamp when checkpoint preparation started (logging only).

    Returns:
        The async save request produced by ``dist_checkpointing.save`` (may be
        None for synchronous saves, depending on that API).

    Raises:
        ValueError: if ``args.ckpt_fully_parallel_save_process_group`` is not
            one of the supported values (``'dp'``, ``'ep_dp'``).
    """
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        # TODO Handle non-empty directories (e.g., after a crash during saving).
        ensure_directory_exists(checkpoint_name, check_parent=False)
    if checkpointing_context is not None and 'save_strategy' in checkpointing_context:
        save_strategy = checkpointing_context['save_strategy']
        # Already saved once before - don't need to rerun sharding validation
        validate_sharding_integrity = not args.ckpt_assume_constant_structure
    else:
        validate_sharding_integrity = True
        save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
        if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
            save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
        if args.async_save:
            save_strategy.thread_count = args.dist_ckpt_workers
        else:
            # We don't allow per-rank parallel save for sync save
            logger.warning('Per-rank parallel save is not supported for sync save. '
                           'Setting args.dist_ckpt_workers to 1')
            save_strategy.thread_count = 1
        if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
            # Reuse global metadata gathered at load time, when available,
            # so the save does not have to recompute it.
            cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
            if cached_global_metadata is not None:
                logger.debug("Plugging in the read metadata from the load strategy...")
                save_strategy.cached_global_metadata = cached_global_metadata
            else:
                logger.debug("Failed to plug in the read metadata from the load strategy...")

        if args.ckpt_fully_parallel_save:
            if args.ckpt_fully_parallel_save_process_group == 'dp':
                process_group = mpu.get_data_parallel_group(with_context_parallel=True)
            elif args.ckpt_fully_parallel_save_process_group == 'ep_dp':
                process_group = mpu.get_expert_data_parallel_group()
            else:
                # Fail explicitly instead of hitting an UnboundLocalError on
                # `process_group` below.
                raise ValueError(
                    "Unsupported ckpt_fully_parallel_save_process_group: "
                    f"'{args.ckpt_fully_parallel_save_process_group}' (expected 'dp' or 'ep_dp')"
                )
            save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, process_group,
                                                             args.ckpt_assume_constant_structure)
        # Store save strategy for future checkpoint saves
        if checkpointing_context is not None:
            checkpointing_context['save_strategy'] = save_strategy
    end_ckpt = time()
    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
    async_save_request = dist_checkpointing.save(state_dict, checkpoint_name, save_strategy,
                                                 async_sharded_save=args.async_save,
                                                 validate_access_integrity=validate_sharding_integrity,
                                                 preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn,
                                                 content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata))
    # [ModelOpt]: save sharded modelopt_state
    if has_nvidia_modelopt:
        save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1))
    return async_save_request


def _save_global_dcp_checkpoint(
    args,
    model,
    state_dict,
    checkpoint_name,
    ckpt_format,
):
    """Write a global checkpoint through torch.distributed.checkpoint (DCP).

    Supports both the ``torch_dcp`` and ``fsdp_dtensor`` formats; for the
    latter, the state dict is first run through the FSDP DTensor
    preprocessing pass. The save is synchronous, so None is always returned
    (no async request to track).
    """
    # Only rank 0 (or a non-distributed run) prepares the target directory.
    # TODO Handle non-empty directories (e.g., after a crash during saving).
    is_main_rank = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    if is_main_rank:
        ensure_directory_exists(checkpoint_name, check_parent=False)

    if ckpt_format == "fsdp_dtensor":
        # The fsdp_dtensor format needs its state dict massaged before saving.
        state_dict = preprocess_fsdp_dtensor_state_dict(args, state_dict, model[0])

    writer = torch.distributed.checkpoint.FileSystemWriter(checkpoint_name)
    torch.distributed.checkpoint.save(state_dict=state_dict, storage_writer=writer)
    return None


def _save_local_checkpoint(
    args,
    state_dict,
    sharded_sd_metadata,
    iteration,
    checkpointing_context,
    rank,
    start_ckpt,
):
    """Save a local (per-node) non-persistent checkpoint.

    Converts ``state_dict`` into a tensor-aware form via
    ``MCoreTensorAwareStateDict`` and hands it to the
    ``local_checkpoint_manager`` stored in ``checkpointing_context``.

    Args:
        args: parsed arguments; reads ``non_persistent_local_ckpt_algo``,
            ``ckpt_assume_constant_structure`` and ``async_save``.
        state_dict: the state dict to persist.
        sharded_sd_metadata: metadata dict (or None) used to reject optimizer
            sharding types that local checkpointing does not support.
        iteration: training iteration this checkpoint belongs to.
        checkpointing_context: dict holding ``'local_checkpoint_manager'`` and,
            optionally, a cached ``'local_checkpoint_cache'`` metadata entry.
        rank: this process's rank (logging only).
        start_ckpt: timestamp when checkpoint preparation started (logging only).

    Returns:
        Whatever the local checkpoint manager's ``save`` returns — presumably
        an async request when ``args.async_save`` is set; verify against the
        manager implementation.

    Raises:
        RuntimeError: if ``nvidia_resiliency_ext`` is not installed, or the
            optimizer sharding type is incompatible with local checkpoints.
    """
    # [ModelOpt]: Inject modelopt_state into state_dict
    if has_nvidia_modelopt:
        print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')

    end_ckpt = time()
    logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
    # Imported lazily: this dependency is optional and only required on the
    # local-checkpoint path.
    try:
        from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
    except ModuleNotFoundError:
        raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
                           "checkpointing but was not found. Please ensure it is installed.")
    if (sharded_sd_metadata or {}).get('distrib_optim_sharding_type') in ['fully_reshardable', 'dp_zero_gather_scatter']:
        # Note: Currently full reshardability is not supported when local checkpoints are used.
        raise RuntimeError(
            f"Local checkpointing does not support optimizer sharding type '{sharded_sd_metadata['distrib_optim_sharding_type']}'. "
            "Don't use '--dist-ckpt-optim-fully-reshardable' when saving local checkpoints."
        )
    algo = args.non_persistent_local_ckpt_algo
    # Reuse cached metadata from a previous save only when the checkpoint
    # structure is guaranteed constant between iterations.
    cached_metadata = None
    if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
        cached_metadata = checkpointing_context['local_checkpoint_cache']
    state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
        state_dict, algo=algo, cached_metadata=cached_metadata,
        parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
    )
    async_save_request = checkpointing_context['local_checkpoint_manager'].save(
        state_dict_for_save, iteration, is_async=bool(args.async_save)
    )
    # Cache the metadata so the next constant-structure save can skip recomputation.
    checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
    return async_save_request


def _save_legacy_checkpoint(
    model,
    state_dict,
    checkpoint_name,
    rank,
    start_ckpt
):
    """Write a legacy (non-distributed) checkpoint with plain ``torch.save``.

    The ModelOpt state, when the modelopt integration is present, is folded
    into ``state_dict`` before it is written out. Always returns None: this
    save path is synchronous, so there is no async request to track.
    """
    # [ModelOpt]: Inject modelopt_state into state_dict
    if has_nvidia_modelopt:
        save_modelopt_state(model, state_dict)

    end_ckpt = time()
    # Lazy %-style args avoid formatting the message unless DEBUG is enabled.
    logger.debug("rank: %s, takes %s to prepare state dict for ckpt ", rank, end_ckpt - start_ckpt)

    # Make sure the parent directory exists, then serialize in one file.
    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)
    return None


@_disable_gc()
def _async_delete_checkpoint_impl(save_path, iteration_to_delete, log_progress=False, lower_priority=False,
cpu_priority=None, io_priority=None):
Expand Down
Loading