feat: refactor mcore train/forward utilities #1654
Conversation
Force-pushed 53cd230 → 5417491, f510015 → 2820fd4, and a2a8a51 → fc94f8b.
terrykong
left a comment
thanks for all the refactoring work @ashors1
- I think we should merge this after @asolergi-nv looks into the padding issue in the packing path, since it would be good to confirm that this PR doesn't introduce a regression there
cc @ananthsub
yuki-97
left a comment
thank you so much for the refactor efforts! just one last step to complete. 🎉
Force-pushed 2820fd4 → 5fd527d.
📝 Walkthrough
This PR refactors Megatron-based training infrastructure by removing …

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Training Loop
    participant FWD as megatron_forward_backward
    participant MF as model_forward
    participant PP as Post-Processor<br/>(Loss/Logprobs/TopK)
    participant BC as Pipeline Parallel<br/>Broadcast
    participant Stages as PP Stages
    Client->>FWD: Call with data_iterator,<br/>post_processing_fn
    FWD->>FWD: Create forward_step<br/>partial
    FWD->>FWD: Call Megatron<br/>forward_backward
    Note over FWD: Executes across<br/>pipeline stages
    FWD->>MF: model_forward on stage
    MF-->>FWD: logits (on last stage)
    FWD->>PP: Apply post-processing<br/>(e.g., compute loss)
    PP-->>FWD: processed output
    FWD->>BC: broadcast_loss_metrics<br/>or broadcast_tensors
    BC->>Stages: Gather from last stage
    Stages-->>BC: Distribute to all stages
    BC-->>FWD: Broadcasted result
    FWD-->>Client: Final output
```
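Read alongside the diagram, a call from the training loop might look roughly like the sketch below. The argument names follow the diagram (`data_iterator`, `post_processing_fn`), but the actual signature of `megatron_forward_backward` in `nemo_rl/models/megatron/train.py` may differ, and `mean_logit_loss` is a toy stand-in for the real loss/logprob/top-k post-processors.

```python
def mean_logit_loss(logits, batch):
    """Toy post-processor standing in for the loss/logprob/top-k hooks."""
    return logits.float().mean()


def train_step(model, data_iterator, megatron_forward_backward):
    # The training loop hands a microbatch iterator plus a post-processing
    # function to the shared utility, which builds the forward_step partial,
    # runs Megatron's pipeline schedule, applies the post-processor on the
    # last pipeline stage, and broadcasts the result back to all stages.
    return megatron_forward_backward(
        model=model,
        data_iterator=data_iterator,
        post_processing_fn=mean_logit_loss,
    )
```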
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/models/megatron/common.py (1)
60-109: ⚠️ Potential issue | 🔴 Critical
Fix device metadata handling in `broadcast_tensor` to avoid a `TypeError` at line 108 and wrong-device tensor allocation.
Line 63 gets `torch.cuda.current_device()` (which returns an int), then line 108 calls `torch.device(device)`, which will raise a `TypeError` since `torch.device()` does not accept an int. Additionally, device metadata is not broadcast (line 75 only includes shape and dtype), so non-source ranks always assume the default CUDA device. If the source tensor is on a non-default device, non-source ranks allocate and validate against the wrong device.
Broadcast the device in metadata and use the received device directly in comparisons:
🔧 Suggested fix
```diff
-    # Assume operations happen on the default CUDA device for the rank
-    # TODO: Consider making device explicit if needed, e.g., derive from tensor on src
-    device = torch.cuda.current_device()
+    # Device will be broadcast as part of metadata
+    device = None
@@
-        metadata = [tensor.shape, tensor.dtype]
+        metadata = [tensor.shape, tensor.dtype, tensor.device]
@@
-        received_shape, received_dtype = object_list[0]
+        received_shape, received_dtype, received_device = object_list[0]
@@
-            tensor = torch.empty(received_shape, dtype=received_dtype, device=device)
+            tensor = torch.empty(received_shape, dtype=received_dtype, device=received_device)
@@
-        if tensor.device != torch.device(device):
+        if tensor.device != received_device:
             raise ValueError(
                 f"Rank {rank}: Provided tensor is on device {tensor.device}, "
-                f"but expected broadcast device is {device}."
+                f"but expected broadcast device is {received_device}."
             )
```
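For orientation, here is a minimal self-contained sketch of a broadcast helper that carries the device in the metadata, along the lines of the suggested fix. It is an illustration only: the real `broadcast_tensor` in `nemo_rl/models/megatron/common.py` has a different surrounding structure and error handling.

```python
import torch
import torch.distributed as dist


def broadcast_tensor(tensor, src_rank, group):
    """Broadcast `tensor` from `src_rank` to every rank in `group` (sketch).

    Shape, dtype, *and device* are shipped as metadata first, so non-source
    ranks allocate their receive buffer on the same device as the source
    instead of assuming the default CUDA device.
    """
    rank = dist.get_rank()
    is_src = rank == src_rank
    # Step 1: broadcast metadata with the object collective.
    object_list = [(tensor.shape, tensor.dtype, tensor.device)] if is_src else [None]
    dist.broadcast_object_list(object_list, src=src_rank, group=group)
    shape, dtype, device = object_list[0]
    # Step 2: non-source ranks allocate a buffer on the received device.
    if not is_src:
        tensor = torch.empty(shape, dtype=dtype, device=device)
    # Step 3: broadcast the tensor payload itself.
    dist.broadcast(tensor, src=src_rank, group=group)
    return tensor
```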
🤖 Fix all issues with AI agents
In `@nemo_rl/models/megatron/pipeline_parallel.py`:
- Around line 1-13: Update the NVIDIA copyright header year from 2025 to 2026 at
the top of the file (the existing license block in the current module) so the
header reads 2026, ensuring the rest of the Apache License text remains
unchanged; modify the banner in nemo_rl/models/megatron/pipeline_parallel.py
accordingly.
- Around line 53-67: The code currently selects the first True in obj_flags
without ensuring uniqueness; update the logic around obj_flags/pp_size/pp_group
to validate exactly one rank owns the object by counting True entries in
obj_flags (or collecting indices) and raise an error if count != 1, then set
src_rank to the single owning rank; preserve existing behavior of raising when
none exist but also raise when multiple ranks have the object to enforce the
function's contract (use the variables obj_flags, src_rank, and pp_group to
locate and implement this check).
- Around line 135-147: The last pipeline stage currently skips broadcast calls
when a tensor value is None, causing a deadlock because other stages still call
broadcast_tensor; update the is_pipeline_last_stage branch in the loop over
tensors so that for every key you call broadcast_tensor even if tensors[name] is
None—e.g., replace None with a sentinel empty tensor/object appropriate for your
dtype/device before calling broadcast_tensor(tensor, current_rank, pp_group)
(use the same device/dtype logic as in broadcast_tensor); ensure
broadcasted_tensors[name] gets the return value so the collectives run on the
last stage as well (refer to is_pipeline_last_stage, broadcast_tensor, tensors,
broadcasted_tensors, current_rank, last_rank, pp_group).
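As a rough illustration of the uniqueness contract requested in the second `pipeline_parallel.py` item above, a standalone helper might look like this. The name `find_unique_owner` is hypothetical; the real code works on `obj_flags` and `src_rank` inline.

```python
def find_unique_owner(obj_flags: list[bool]) -> int:
    """Return the index of the single True entry, enforcing exactly one owner.

    Hypothetical helper mirroring the check described above: zero owners and
    multiple owners are both contract violations.
    """
    owners = [i for i, has_obj in enumerate(obj_flags) if has_obj]
    if len(owners) != 1:
        raise ValueError(
            f"Expected exactly one pipeline-parallel rank to own the object, "
            f"found {len(owners)} (ranks {owners})."
        )
    return owners[0]


src_rank = find_unique_owner([False, True, False])  # -> 1
```

For the third item (the `None`-value deadlock), one way to keep every rank inside the collective is a zero-element sentinel. This sketch reuses the names from the comment (`tensors`, `broadcasted_tensors`, `broadcast_tensor`, `pp_group`) but is not the actual loop in `pipeline_parallel.py`.

```python
from typing import Callable, Optional

import torch


def broadcast_all_keys(
    tensors: dict[str, Optional[torch.Tensor]],
    current_rank: int,
    pp_group,
    broadcast_tensor: Callable,  # collective helper; all ranks must call it per key
) -> dict[str, torch.Tensor]:
    """Sketch: the last stage issues one broadcast per key even for None values."""
    broadcasted_tensors = {}
    for name, value in tensors.items():
        if value is None:
            # Zero-element sentinel so the collective still runs for this key.
            value = torch.empty(0, device=torch.cuda.current_device())
        broadcasted_tensors[name] = broadcast_tensor(value, current_rank, pp_group)
    return broadcasted_tensors
```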
In `@nemo_rl/models/megatron/train.py`:
- Around line 1-13: Update the NVIDIA copyright header in
nemo_rl/models/megatron/train.py by changing the year from 2025 to 2026 in the
file header; ensure the top comment block (the license header in train.py)
reflects "Copyright (c) 2026, NVIDIA CORPORATION" and keeps the rest of the
Apache 2.0 header unchanged.
- Around line 50-100: The call to apply_temperature_scaling in model_forward
unconditionally modifies logits (affecting training loss); change it so
temperature scaling runs only for inference/post‑processing paths: in
model_forward (and the similar block at lines ~165-175), detect whether
generation/inference postprocessing is active (e.g., check cfg["generation"] and
the postprocessor type or a flag like cfg.postprocessor == "inference" /
cfg.get("mode") == "inference") and only call
apply_temperature_scaling(output_tensor, cfg) when that condition is true;
otherwise leave logits unchanged for training.
- Around line 428-431: The commented-out pipeline-parallel guard around
is_pipeline_last_stage(ignore_virtual=True) and the return of
output_tensor.new_zeros(()) should be either removed or documented: either
delete those three commented lines if they are no longer needed, or replace them
with an explanatory comment that names is_pipeline_last_stage and
output_tensor.new_zeros and states why the PP guard is intentionally disabled
(e.g., because all PP stages now produce logits, testing reasons, or a temporary
debugging bypass) so future readers know the rationale.
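To make the temperature-scaling guard from the second `train.py` item concrete, here is a hedged sketch. The config keys `generation` and `mode` are assumptions taken from the review comment, not the verified nemo_rl config schema.

```python
def maybe_apply_temperature_scaling(output_tensor, cfg, apply_temperature_scaling):
    """Apply temperature scaling only on inference/post-processing paths.

    Sketch only: the inference check below is an assumption based on the
    review comment; training logits are returned unchanged.
    """
    is_inference = bool(cfg.get("generation")) or cfg.get("mode") == "inference"
    if is_inference:
        return apply_temperature_scaling(output_tensor, cfg)
    return output_tensor
```

Similarly, if the commented-out pipeline-parallel guard from the last item is kept, a short explanatory note along these lines (wording hypothetical) would document the rationale:

```python
# NOTE: The is_pipeline_last_stage(ignore_virtual=True) guard is intentionally
# disabled here because <reason, e.g., every PP stage now returns logits>.
# if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
#     return output_tensor.new_zeros(())
```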
In `@tests/unit/models/megatron/test_train.py`:
- Around line 343-351: The test assigns the return of megatron_forward_backward
to an unused variable result, triggering Ruff F841; fix it by either removing
the assignment and calling megatron_forward_backward(...) directly or by
assigning to a deliberately ignored name (e.g., _ ) so the return value is not
flagged as unused—update the call site where result is set (the
megatron_forward_backward invocation) accordingly.
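The Ruff F841 fix is mechanical; either form below works. The argument list here is illustrative, not the actual call site in `test_train.py`.

```python
# Option 1: drop the binding entirely.
megatron_forward_backward(model, data_iterator, post_processing_fn=post_processing_fn)

# Option 2: bind to an intentionally ignored name.
_ = megatron_forward_backward(model, data_iterator, post_processing_fn=post_processing_fn)
```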
🧹 Nitpick comments (1)
nemo_rl/models/megatron/data.py (1)
72-80: Align type hints with the optional `straggler_timer` and the actual return type.
`process_microbatch` now accepts `Optional[StragglerDetector]` and returns `ProcessedInputs`, but the upstream signatures still require a non-Optional timer, and the return annotation still advertises a tuple. This makes type checking misleading.
🔧 Suggested fix
```diff
 def make_processed_microbatch_iterator(
     raw_iterator: Iterator[BatchedDataDict[Any]],
     cfg: dict[str, Any],
     seq_length_key: Optional[str],
     pad_individual_seqs_to_multiple_of: int,
     pad_packed_seq_to_multiple_of: int,
-    straggler_timer: StragglerDetector,
+    straggler_timer: Optional[StragglerDetector],
     pad_full_seq_to: Optional[int],
 ) -> Iterator[ProcessedMicrobatch]:
@@
 def get_microbatch_iterator(
     data: BatchedDataDict[Any],
     cfg: dict[str, Any],
     mbs: int,
-    straggler_timer: StragglerDetector,
+    straggler_timer: Optional[StragglerDetector],
     seq_length_key: Optional[str] = None,
 ) -> Tuple[Iterator[ProcessedMicrobatch], int, int, int, int]:
@@
 def process_microbatch(
@@
-    straggler_timer: Optional[StragglerDetector] = None,
-) -> tuple[
-    torch.Tensor,
-    torch.Tensor,
-    Optional[torch.Tensor],
-    Optional[torch.Tensor],
-    Optional[PackedSeqParams],
-    Optional[torch.Tensor],
-]:
+    straggler_timer: Optional[StragglerDetector] = None,
+) -> ProcessedInputs:
```

Also applies to: 126-132, 208-216
Force-pushed d63506c → 169ad3a.
Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Force-pushed 4d5a5cd → 41a4e68.
Cancelled CI for now; waiting for #1902 to be merged first.
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Nightly test results:
Issues
Closes #1593.
Closes #1744.
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Refactor
Tests