Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
a760b5e
Handle quantized CUDA tensors in async checkpoint writer
sbak5 Mar 12, 2026
1587314
Lint applied
sbak5 Mar 13, 2026
a1aad02
Import resiliency-ext async checkpointing
sbak5 Mar 13, 2026
65bb819
Revert "Handle quantized CUDA tensors in async checkpoint writer"
dimapihtar Mar 17, 2026
814f2fb
keep both nvrx & mcore async save strategies
dimapihtar Mar 17, 2026
9fc5757
rename variable
dimapihtar Mar 17, 2026
c2559d9
refactor get_async_strategy
dimapihtar Mar 17, 2026
1b9ad87
pass async strategy to save/load strategy
dimapihtar Mar 18, 2026
55b6a00
Merge branch 'main' into sbak/ckpt_migrate
dimapihtar Mar 18, 2026
4fa312f
fix imports
dimapihtar Mar 18, 2026
cb04e5a
properly pass async_strategy to async_save
dimapihtar Mar 18, 2026
247027f
properly pass async_strategy
dimapihtar Mar 18, 2026
c3673fd
properly pass async_strategy load
dimapihtar Mar 18, 2026
cc6fd79
remove extra code
dimapihtar Mar 18, 2026
9c69546
fix code style
dimapihtar Mar 18, 2026
c5adc81
add deprecation warning
dimapihtar Mar 18, 2026
79ed36e
fix style
dimapihtar Mar 18, 2026
33e2872
fix code style
dimapihtar Mar 18, 2026
ac936ed
set mcore async-strategy for some func tests
dimapihtar Mar 18, 2026
1063aef
update nvrx version
dimapihtar Mar 18, 2026
3a7af3e
add unit tests
dimapihtar Mar 18, 2026
bfeb20a
move warning
dimapihtar Mar 18, 2026
6ba7a29
revert changes
dimapihtar Mar 18, 2026
bf7f792
fix bug
dimapihtar Mar 18, 2026
0f3779d
update unit tests
dimapihtar Mar 18, 2026
0b71e78
revert changes
dimapihtar Mar 18, 2026
755483c
Revert "revert changes"
dimapihtar Mar 18, 2026
70ce1f5
Revert "update nvrx version"
dimapihtar Mar 18, 2026
d18bec2
Merge branch 'main' into sbak/ckpt_migrate
dimapihtar Mar 18, 2026
c3dabfb
update nvrx version
dimapihtar Mar 18, 2026
51c5501
fix style
dimapihtar Mar 18, 2026
d2cfee3
fix unit tests
dimapihtar Mar 18, 2026
c5c86f5
fix unit tests
dimapihtar Mar 18, 2026
0f81f06
avoid async_strategy param at serialization.load()
dimapihtar Mar 19, 2026
5ceef34
fix unit test
dimapihtar Mar 19, 2026
57c422c
fix unit test
dimapihtar Mar 19, 2026
9f690f0
move warning
dimapihtar Mar 19, 2026
9a36493
fix unit test
dimapihtar Mar 19, 2026
48c0d95
fix unit test
dimapihtar Mar 19, 2026
67703f2
Revert "fix unit test"
dimapihtar Mar 19, 2026
c8ed36c
disable async_save
dimapihtar Mar 19, 2026
8030764
fix unit test
dimapihtar Mar 19, 2026
a0591c8
fix typo
dimapihtar Mar 19, 2026
64f4280
Merge branch 'main' into sbak/ckpt_migrate
dimapihtar Mar 19, 2026
1f4979e
disable async_save
dimapihtar Mar 19, 2026
ddc37b5
fix unit test
dimapihtar Mar 19, 2026
2c85bea
fix warning
dimapihtar Mar 19, 2026
ffe53e1
fix unit test
dimapihtar Mar 19, 2026
6833d87
Revert "update nvrx version"
dimapihtar Mar 20, 2026
e26b2d2
Merge branch 'main' into sbak/ckpt_migrate
dimapihtar Mar 20, 2026
2a92e71
update nvrx version
dimapihtar Mar 20, 2026
47e190d
update uv.lock
dimapihtar Mar 20, 2026
c56a3eb
Fix issue setting up `CachedFileSystemReader` and incorrect use of ck…
sbak5 Mar 20, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions megatron/core/dist_checkpointing/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def load(
)
merge(common_state_dict, sharded_objects)

loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
async_strategy = getattr(common_state_dict.get("args"), "async_strategy", "nvrx")
loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir, async_strategy)

merge(common_state_dict, loaded_state_dict)

Expand Down Expand Up @@ -322,6 +323,7 @@ def save(
Callable[[CommonStateDict], StateDict]
] = None,
content_metadata: Optional[dict] = None,
async_strategy: Optional[str] = "nvrx",
) -> Optional[AsyncRequest]:
"""Saving entrypoint.

Expand Down Expand Up @@ -434,7 +436,7 @@ def metadata_finalize_fn():
raise CheckpointingException(
f'Cannot apply async_save to non-async strategy {sharded_strategy}'
)
async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir)
async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir, async_strategy)
async_request.finalize_fns.append(metadata_finalize_fn)
return async_request

Expand Down
32 changes: 26 additions & 6 deletions megatron/core/dist_checkpointing/strategies/fully_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,22 @@ def __init__(

self.cached_distribution: Optional[ShardDistribution] = None

def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
def async_save(
self,
sharded_state_dict: ShardedStateDict,
checkpoint_dir: Path,
async_strategy: str = "nvrx",
):
""" """
if not isinstance(self.base_strategy, AsyncSaveShardedStrategy):
raise CheckpointingException(
f'Cannot apply async_save to non-async base strategy {self.base_strategy}'
)
self.apply_saving_parallelization(sharded_state_dict)
return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir)
return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir, async_strategy)

def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
""" """
self.apply_saving_parallelization(sharded_state_dict)
return self.base_strategy.save(sharded_state_dict, checkpoint_dir)

Expand Down Expand Up @@ -135,6 +142,7 @@ def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) ->

@property
def can_handle_sharded_objects(self):
""" """
return self.base_strategy.can_handle_sharded_objects


Expand Down Expand Up @@ -185,7 +193,12 @@ def __init__(
self.cached_global_metadata: Optional[Metadata] = None

@debug_time("FullyParallelLoadStrategyWrapper.load", logger)
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
def load(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove any load related changes. We don't have anything yet for loading.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sbak5 load calls _get_filesystem_reader: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/dist_checkpointing/strategies/torch.py#L807

which uses CachedMetadataFileSystemReader: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/dist_checkpointing/strategies/torch.py#L766 so we need to import it properly in respect to async_strategy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm but I guess we can take it from common state dict
let me see

self,
sharded_state_dict: ShardedStateDict,
checkpoint_dir: Path,
async_strategy: str = "nvrx",
) -> StateDict:
"""Distributes the load and calls underlying strategy only for parts of the state dict.

Steps:
Expand Down Expand Up @@ -218,7 +231,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
loaded_state_dict = {}

if get_pg_size(self.parallelization_group) <= 1:
return self.base_strategy.load(sharded_state_dict, checkpoint_dir)
return self.base_strategy.load(sharded_state_dict, checkpoint_dir, async_strategy)

# Step 1 and 2: exchange load metadata and distribute the load
with debug_time("self.apply_loading_parallelization", logger):
Expand All @@ -245,11 +258,13 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
), "sharded_state_dict is not empty after deferring tensors and objects"
with debug_time("base_load_ShardedObjects", logger):
# Load sharded objects first
loaded_objects = self.base_strategy.load(to_load_objects, checkpoint_dir)
loaded_objects = self.base_strategy.load(
to_load_objects, checkpoint_dir, async_strategy
)

with debug_time("base_load_ShardedTensors", logger):
# Load sharded tensors separately
loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)
loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir, async_strategy)

with debug_time("self.exchange_loaded_tensors", logger):

Expand Down Expand Up @@ -390,18 +405,23 @@ def apply_loading_parallelization(

@property
def can_handle_sharded_objects(self):
""" """
return self.base_strategy.can_handle_sharded_objects

def load_tensors_metadata(self, checkpoint_dir: Path):
""" """
return self.base_strategy.load_tensors_metadata(checkpoint_dir)

def load_sharded_metadata(self, checkpoint_dir: Path):
""" """
return self.base_strategy.load_sharded_metadata(checkpoint_dir)

def check_backend_compatibility(self, loaded_version):
""" """
return self.base_strategy.check_backend_compatibility(loaded_version)

def check_version_compatibility(self, loaded_version):
""" """
return self.base_strategy.check_version_compatibility(loaded_version)


Expand Down
Loading
Loading