-
Notifications
You must be signed in to change notification settings - Fork 251
feat: refactor mcore train/forward utilities #1654
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+2,095
−551
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
70e81a4
refactor megatron data utils
ashors1 2c232d5
refactor train/forward utilities
ashors1 fc04228
add do_not_average_loss arg
ashors1 0ac1aa9
remove unused import
ashors1 1bd336e
add train unit tests
ashors1 afbbbe7
add unit tests
ashors1 ed26b9e
fix unit test
ashors1 390abba
add missing copyright header
ashors1 605f850
lint
ashors1 68459d2
lint
ashors1 3c19522
rebase & address feedback
ananthsub 5e40147
lint
ananthsub d5fd035
refactor
ananthsub 1cadb1e
address feedback
ananthsub 7f66e95
update
ananthsub 987b036
lint
ananthsub a54b079
unit tests
ananthsub 41a4e68
move metric aggregation to a function matching automodel
ananthsub 433bc3c
Merge branch 'main' into ashors/mcore-train
yuki-97 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Pipeline parallel utilities for Megatron models.""" | ||
|
|
||
| from typing import Any, Optional | ||
|
|
||
| import torch | ||
| from megatron.core.parallel_state import ( | ||
| get_pipeline_model_parallel_group, | ||
| get_pipeline_model_parallel_last_rank, | ||
| get_pipeline_model_parallel_world_size, | ||
| is_pipeline_last_stage, | ||
| ) | ||
|
|
||
|
|
||
| def broadcast_obj_from_pp_rank(obj: Any) -> Any: | ||
| """Broadcast an object across pipeline parallel ranks. | ||
|
|
||
| This utility function handles broadcasting an object from the rank that owns it | ||
| to all other pipeline parallel ranks. If only one rank has the object (non-None), | ||
| it will be broadcast to all other ranks. | ||
|
|
||
| Args: | ||
| obj: The object to broadcast. Can be None on ranks that don't own it. | ||
|
|
||
| Returns: | ||
| The object on all ranks (either the original or the broadcast copy). | ||
|
|
||
| Raises: | ||
| ValueError: If the object doesn't exist on any pipeline parallel rank. | ||
| """ | ||
| pp_size = get_pipeline_model_parallel_world_size() | ||
| pp_group = get_pipeline_model_parallel_group() | ||
|
|
||
| if pp_size == 1: | ||
| return obj | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # 1. Gather presence flags from all PP ranks to find the source rank | ||
| # ------------------------------------------------------------------ | ||
| has_obj = obj is not None | ||
| obj_flags = [None] * pp_size | ||
| torch.distributed.all_gather_object(obj_flags, has_obj, group=pp_group) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # 2. Identify the owning rank (the only rank with True flag) | ||
| # ------------------------------------------------------------------ | ||
| true_ranks = [rank for rank, flag in enumerate(obj_flags) if flag] | ||
| if not true_ranks: | ||
| raise ValueError("Object must exist on at least one PP rank") | ||
ananthsub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if len(true_ranks) > 1: | ||
| raise ValueError(f"Object present on multiple PP ranks: {true_ranks}") | ||
| src_rank = true_ranks[0] | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # 3. Broadcast the object from the source rank to all ranks | ||
| # ------------------------------------------------------------------ | ||
| # Use broadcast_object_list which is more robust than all_gather_object | ||
| obj_list = [obj] | ||
| pp_ranks = torch.distributed.get_process_group_ranks(pp_group) | ||
| global_src = pp_ranks[src_rank] | ||
| torch.distributed.broadcast_object_list(obj_list, src=global_src, group=pp_group) | ||
|
|
||
| return obj_list[0] | ||
|
|
||
|
|
||
| def broadcast_loss_metrics_from_last_stage(loss_metrics: Optional[list] = None) -> list: | ||
| """Broadcast loss metrics from the last pipeline stage to all stages. | ||
|
|
||
| This utility handles the common pattern where loss computation happens on the last | ||
| pipeline stage and needs to be broadcast to all other stages. | ||
|
|
||
| Args: | ||
| loss_metrics: List of loss metrics if on last stage, None otherwise | ||
|
|
||
| Returns: | ||
| List of loss metrics on all ranks | ||
| """ | ||
| pp_group = get_pipeline_model_parallel_group() | ||
| last_rank = get_pipeline_model_parallel_last_rank() | ||
|
|
||
| if is_pipeline_last_stage(ignore_virtual=True): | ||
| metrics_to_broadcast = [loss_metrics] | ||
| torch.distributed.broadcast_object_list( | ||
| metrics_to_broadcast, | ||
| src=last_rank, | ||
| group=pp_group, | ||
| ) | ||
| return loss_metrics | ||
| else: | ||
| metrics_to_broadcast = [None] | ||
| torch.distributed.broadcast_object_list( | ||
| metrics_to_broadcast, | ||
| src=last_rank, | ||
| group=pp_group, | ||
| ) | ||
| return metrics_to_broadcast[0] | ||
|
|
||
|
|
||
| def broadcast_tensors_from_last_stage( | ||
| tensors: dict[str, Optional[torch.Tensor]], | ||
| ) -> dict[str, torch.Tensor]: | ||
| """Broadcast multiple tensors from the last pipeline stage to all stages. | ||
|
|
||
| Args: | ||
| tensors: Dictionary mapping tensor names to tensors (None on non-last stages) | ||
| pp_group: Pipeline parallel group (auto-detected if None) | ||
|
|
||
| Returns: | ||
| Dictionary of broadcasted tensors on all ranks | ||
| """ | ||
| pp_group = get_pipeline_model_parallel_group() | ||
|
|
||
| from nemo_rl.models.megatron.common import broadcast_tensor | ||
|
|
||
| last_rank = get_pipeline_model_parallel_last_rank() | ||
| current_rank = torch.distributed.get_rank() | ||
|
|
||
| broadcasted_tensors = {} | ||
|
|
||
| if is_pipeline_last_stage(ignore_virtual=True): | ||
| # Broadcast tensors from last stage | ||
| for name, tensor in tensors.items(): | ||
| if tensor is None: | ||
| raise ValueError( | ||
| f"Last PP stage must provide tensor '{name}' for broadcast." | ||
| ) | ||
| broadcasted_tensors[name] = broadcast_tensor(tensor, current_rank, pp_group) | ||
| else: | ||
| # Receive tensors on other stages | ||
| for name in tensors.keys(): | ||
| broadcasted_tensors[name] = broadcast_tensor(None, last_rank, pp_group) | ||
ananthsub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| return broadcasted_tensors | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.