diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py index 28af36b11b..69912f209c 100644 --- a/nemo_rl/models/megatron/common.py +++ b/nemo_rl/models/megatron/common.py @@ -11,28 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -from typing import Any, Iterator, Optional + +from typing import Any, Optional import torch import torch.distributed as dist -from megatron.bridge.training.state import GlobalState -from megatron.core.models.gpt import GPTModel -from megatron.core.parallel_state import ( - get_context_parallel_group, - get_context_parallel_world_size, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, -) from megatron.core.transformer.moe.moe_utils import ( clear_aux_losses_tracker, get_moe_layer_wise_logging_tracker, reduce_aux_losses_tracker_across_ranks, ) -from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper -from nemo_rl.distributed.batched_data_dict import BatchedDataDict - def _round_up_to_multiple(value: int, multiple: int) -> int: return ( @@ -42,119 +31,6 @@ def _round_up_to_multiple(value: int, multiple: int) -> int: ) -def forward_step_arbitrary_loss( - state: GlobalState, - global_valid_seqs: torch.Tensor, - global_valid_toks: torch.Tensor, - data_iterator: Iterator[BatchedDataDict[Any]], - model: GPTModel, - loss_fn: LossFunction, - pack_sequences: bool = False, - defer_fp32_logits: Optional[bool] = None, - cp_normalize: bool = True, - policy_cfg: Optional[dict] = None, -): - """Forward training step with support for packed sequences and context parallelism. - - Args: - state (GlobalState): Global state for the run - global_valid_seqs: Global count of valid sequences - global_valid_toks: Global count of valid tokens - data_iterator: Input data iterator - model (GPTModel): The GPT Model - loss_fn (LossFunction): Loss function to apply - pack_sequences (bool): Whether to pack sequences for efficiency - defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32 - cp_normalize (bool): Whether to normalize the loss by the cp_size - policy_cfg (Optional[dict]): Policy configuration containing generation parameters - - Notes on packed sequences with context parallelism (CP): - - When CP > 1, each sequence is padded to a multiple of (cp_size * 2) - - The factor of 2 ensures load balancing for causal attention - - cu_seqlens tracks actual sequence boundaries - - cu_seqlens_padded tracks padded sequence boundaries for CP - - Requires TransformerEngine >= 1.10 for CP support - """ - straggler_timer = state.straggler_timer - - # Get the pre-processed microbatch from the iterator - processed_mb = next(data_iterator) - - # Extract the processed components - data_dict = processed_mb.data_dict - input_ids = processed_mb.input_ids - input_ids_cp_sharded = processed_mb.input_ids_cp_sharded - attention_mask = processed_mb.attention_mask - position_ids = processed_mb.position_ids - packed_seq_params = processed_mb.packed_seq_params - cu_seqlens_padded = processed_mb.cu_seqlens_padded - - multimodal_data = data_dict.get_multimodal_dict( - as_tensors=True, device=input_ids_cp_sharded.device - ) - if len(multimodal_data) > 0: - position_ids = None - - additional_kwargs = {} - # Mamba models currently do not support packed_seq_params - if packed_seq_params is not None: - additional_kwargs["packed_seq_params"] = 
packed_seq_params - - if defer_fp32_logits: - additional_kwargs["fp32_output"] = False - - with straggler_timer: - output_tensor = model( - input_ids=input_ids_cp_sharded, - position_ids=position_ids, - attention_mask=attention_mask, - **additional_kwargs, - **multimodal_data, - ) - - # Apply temperature scaling to logits for training - # This matches the dtensor worker's _apply_temperature_scaling in the train method - if ( - policy_cfg is not None - and "generation" in policy_cfg - and policy_cfg["generation"] is not None - ): - output_tensor.div_(policy_cfg["generation"]["temperature"]) - - # Unpack the output tensor if we did packed sequences - if pack_sequences and packed_seq_params is not None: - # remove padding - loss_fn = SequencePackingLossWrapper( - loss_fn=loss_fn, - cu_seqlens_q=packed_seq_params.cu_seqlens_q, - cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded, - ) - - loss_data = data_dict - - loss_fn_wrapped = partial( - loss_fn, - data=loss_data, - global_valid_seqs=global_valid_seqs, - global_valid_toks=global_valid_toks, - vocab_parallel_rank=get_tensor_model_parallel_rank(), - vocab_parallel_group=get_tensor_model_parallel_group(), - context_parallel_group=get_context_parallel_group(), - ) - - if cp_normalize: - cp_size = get_context_parallel_world_size() - orig_loss_fn_wrapped = loss_fn_wrapped - - def _div_by_cp_size(*args, **kwargs): - loss, metrics = orig_loss_fn_wrapped(*args, **kwargs) - return loss / cp_size, metrics - - loss_fn_wrapped = _div_by_cp_size - - return output_tensor, loss_fn_wrapped - - def broadcast_tensor( tensor: torch.Tensor | None, src_rank: int, group: dist.ProcessGroup ) -> torch.Tensor: diff --git a/nemo_rl/models/megatron/data.py b/nemo_rl/models/megatron/data.py index 5adbec29c7..7c765f19b5 100644 --- a/nemo_rl/models/megatron/data.py +++ b/nemo_rl/models/megatron/data.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import nullcontext from dataclasses import dataclass from typing import Any, Iterator, Optional, Tuple @@ -211,7 +212,7 @@ def process_microbatch( pad_packed_seq_to_multiple_of: int = 1, pad_full_seq_to: Optional[int] = None, pack_sequences: bool = False, - straggler_timer: StragglerDetector = None, + straggler_timer: Optional[StragglerDetector] = None, ) -> tuple[ torch.Tensor, torch.Tensor, @@ -221,7 +222,8 @@ def process_microbatch( Optional[torch.Tensor], ]: """Process a microbatch for Megatron model forward pass.""" - with straggler_timer(bdata=True): + ctx = straggler_timer(bdata=True) if straggler_timer is not None else nullcontext() + with ctx: input_ids = data_dict["input_ids"] attention_mask = None position_ids = None @@ -294,15 +296,15 @@ def process_global_batch( *, batch_idx: int, batch_size: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> dict[str, Any]: """Process a global batch and compute normalization factors. 
Args: - data: Full dataset + data: Full dataset to extract a batch from + loss_fn: Loss function (used to check loss type for token-level validation) + dp_group: Data parallel process group for all-reduce batch_idx: Index of batch to extract batch_size: Size of batch to extract - loss_fn: Loss function (used to check loss type) - dp_mesh: Data parallel mesh Returns: Dictionary containing: diff --git a/nemo_rl/models/megatron/pipeline_parallel.py b/nemo_rl/models/megatron/pipeline_parallel.py new file mode 100644 index 0000000000..7728f80f65 --- /dev/null +++ b/nemo_rl/models/megatron/pipeline_parallel.py @@ -0,0 +1,146 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pipeline parallel utilities for Megatron models.""" + +from typing import Any, Optional + +import torch +from megatron.core.parallel_state import ( + get_pipeline_model_parallel_group, + get_pipeline_model_parallel_last_rank, + get_pipeline_model_parallel_world_size, + is_pipeline_last_stage, +) + + +def broadcast_obj_from_pp_rank(obj: Any) -> Any: + """Broadcast an object across pipeline parallel ranks. + + This utility function handles broadcasting an object from the rank that owns it + to all other pipeline parallel ranks. If only one rank has the object (non-None), + it will be broadcast to all other ranks. + + Args: + obj: The object to broadcast. Can be None on ranks that don't own it. + + Returns: + The object on all ranks (either the original or the broadcast copy). + + Raises: + ValueError: If the object doesn't exist on any pipeline parallel rank. + """ + pp_size = get_pipeline_model_parallel_world_size() + pp_group = get_pipeline_model_parallel_group() + + if pp_size == 1: + return obj + + # ------------------------------------------------------------------ + # 1. Gather presence flags from all PP ranks to find the source rank + # ------------------------------------------------------------------ + has_obj = obj is not None + obj_flags = [None] * pp_size + torch.distributed.all_gather_object(obj_flags, has_obj, group=pp_group) + + # ------------------------------------------------------------------ + # 2. Identify the owning rank (the only rank with True flag) + # ------------------------------------------------------------------ + true_ranks = [rank for rank, flag in enumerate(obj_flags) if flag] + if not true_ranks: + raise ValueError("Object must exist on at least one PP rank") + if len(true_ranks) > 1: + raise ValueError(f"Object present on multiple PP ranks: {true_ranks}") + src_rank = true_ranks[0] + + # ------------------------------------------------------------------ + # 3. 
Broadcast the object from the source rank to all ranks + # ------------------------------------------------------------------ + # Use broadcast_object_list which is more robust than all_gather_object + obj_list = [obj] + pp_ranks = torch.distributed.get_process_group_ranks(pp_group) + global_src = pp_ranks[src_rank] + torch.distributed.broadcast_object_list(obj_list, src=global_src, group=pp_group) + + return obj_list[0] + + +def broadcast_loss_metrics_from_last_stage(loss_metrics: Optional[list] = None) -> list: + """Broadcast loss metrics from the last pipeline stage to all stages. + + This utility handles the common pattern where loss computation happens on the last + pipeline stage and needs to be broadcast to all other stages. + + Args: + loss_metrics: List of loss metrics if on last stage, None otherwise + + Returns: + List of loss metrics on all ranks + """ + pp_group = get_pipeline_model_parallel_group() + last_rank = get_pipeline_model_parallel_last_rank() + + if is_pipeline_last_stage(ignore_virtual=True): + metrics_to_broadcast = [loss_metrics] + torch.distributed.broadcast_object_list( + metrics_to_broadcast, + src=last_rank, + group=pp_group, + ) + return loss_metrics + else: + metrics_to_broadcast = [None] + torch.distributed.broadcast_object_list( + metrics_to_broadcast, + src=last_rank, + group=pp_group, + ) + return metrics_to_broadcast[0] + + +def broadcast_tensors_from_last_stage( + tensors: dict[str, Optional[torch.Tensor]], +) -> dict[str, torch.Tensor]: + """Broadcast multiple tensors from the last pipeline stage to all stages. + + Args: + tensors: Dictionary mapping tensor names to tensors (None on non-last stages) + pp_group: Pipeline parallel group (auto-detected if None) + + Returns: + Dictionary of broadcasted tensors on all ranks + """ + pp_group = get_pipeline_model_parallel_group() + + from nemo_rl.models.megatron.common import broadcast_tensor + + last_rank = get_pipeline_model_parallel_last_rank() + current_rank = torch.distributed.get_rank() + + broadcasted_tensors = {} + + if is_pipeline_last_stage(ignore_virtual=True): + # Broadcast tensors from last stage + for name, tensor in tensors.items(): + if tensor is None: + raise ValueError( + f"Last PP stage must provide tensor '{name}' for broadcast." + ) + broadcasted_tensors[name] = broadcast_tensor(tensor, current_rank, pp_group) + else: + # Receive tensors on other stages + for name in tensors.keys(): + broadcasted_tensors[name] = broadcast_tensor(None, last_rank, pp_group) + + return broadcasted_tensors diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py new file mode 100644 index 0000000000..95ccc3761d --- /dev/null +++ b/nemo_rl/models/megatron/train.py @@ -0,0 +1,585 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
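# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this change): typical call shapes for the
# pipeline_parallel helpers introduced above, mirroring how the Megatron policy
# worker uses them later in this diff. Assumes torch.distributed and Megatron
# parallel state are already initialized; `size_in_bytes`, `logprobs`, and
# `replicate_across_pp` are placeholder names for illustration only.
from megatron.core.parallel_state import is_pipeline_last_stage

from nemo_rl.models.megatron.pipeline_parallel import (
    broadcast_obj_from_pp_rank,
    broadcast_tensors_from_last_stage,
)


def replicate_across_pp(size_in_bytes, logprobs):
    # Exactly one PP rank passes a non-None object; every rank gets a copy back.
    size_in_bytes = broadcast_obj_from_pp_rank(size_in_bytes)
    # Tensors produced on the last stage are replicated to all earlier stages;
    # non-last stages pass None placeholders keyed by the same names.
    on_last = is_pipeline_last_stage(ignore_virtual=True)
    tensors = {"logprobs": logprobs if on_last else None}
    logprobs = broadcast_tensors_from_last_stage(tensors)["logprobs"]
    return size_in_bytes, logprobs
# ---------------------------------------------------------------------------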
+ +from collections import defaultdict +from contextlib import nullcontext +from functools import partial +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union + +import torch +from megatron.core.models.gpt import GPTModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import ( + get_context_parallel_group, + get_context_parallel_world_size, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, +) +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.utils import StragglerDetector + +from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import ( + allgather_cp_sharded_tensor, + distributed_vocab_topk, + from_parallel_logits_to_logprobs, + from_parallel_logits_to_logprobs_packed_sequences, +) +from nemo_rl.models.megatron.data import ProcessedMicrobatch +from nemo_rl.models.policy import PolicyConfig + +# Union type for any post-processing function (defined after classes below) +PostProcessingFunction = Union[ + "LossPostProcessor", + "LogprobsPostProcessor", + "TopkLogitsPostProcessor", +] + + +def model_forward( + model: GPTModel, + data_dict: BatchedDataDict[Any], + cfg: PolicyConfig, + input_ids_cp_sharded: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + packed_seq_params: Optional[PackedSeqParams] = None, + defer_fp32_logits: Optional[bool] = False, + straggler_timer: Optional[StragglerDetector] = None, +) -> torch.Tensor: + """Perform a single forward pass through the model. + + Args: + model: The model to run forward pass on + data_dict: Dictionary containing batch data + cfg: Policy configuration dictionary + input_ids_cp_sharded: Context-parallel sharded input token IDs + position_ids: Position IDs for tokens + attention_mask: Attention mask for the sequence + packed_seq_params: Parameters for packed sequences (optional) + defer_fp32_logits: Whether to skip the conversion of logits to fp32 + straggler_timer: Straggler detector for profiling the forward pass + + Returns: + torch.Tensor: Output tensor from the model (logits) + """ + multimodal_data = data_dict.get_multimodal_dict( + as_tensors=True, device=input_ids_cp_sharded.device + ) + if len(multimodal_data) > 0: + position_ids = None + + additional_kwargs = {} + # Mamba models currently do not support packed_seq_params + if packed_seq_params is not None: + additional_kwargs["packed_seq_params"] = packed_seq_params + if defer_fp32_logits: + additional_kwargs["fp32_output"] = False + + with straggler_timer() if straggler_timer is not None else nullcontext(): + output_tensor = model( + input_ids=input_ids_cp_sharded, + position_ids=position_ids, + attention_mask=attention_mask, + **additional_kwargs, + **multimodal_data, + ) + + return output_tensor + + +def apply_temperature_scaling( + logits: torch.Tensor, + cfg: PolicyConfig, +) -> torch.Tensor: + """Apply temperature scaling to logits. 
+ + Args: + logits: Logits tensor to scale + cfg: Policy configuration containing generation settings + + Returns: + torch.Tensor: Temperature-scaled logits + """ + if "generation" in cfg and cfg["generation"] is not None: + logits.div_(cfg["generation"]["temperature"]) + return logits + + +def forward_with_post_processing_fn( + data_iterator: Iterator[ProcessedMicrobatch], + model: GPTModel, + cfg: PolicyConfig, + post_processing_fn: PostProcessingFunction, + defer_fp32_logits: Optional[bool] = False, + global_valid_seqs: Optional[torch.Tensor] = None, + global_valid_toks: Optional[torch.Tensor] = None, + straggler_timer: Optional[StragglerDetector] = None, +) -> Tuple[torch.Tensor, Callable]: + """Perform forward pass with pre-processed microbatch and return output tensor and post-processing function. + + This function takes a pre-processed microbatch (with sequence packing already handled), + runs the forward step through the model, and prepares a post-processing function for + post-processing the outputs. + + Args: + data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed) + model: The model to run forward pass on + cfg: Policy configuration dictionary + post_processing_fn: Post-processing function to post-process the logits + defer_fp32_logits: Whether to defer FP32 conversion of logits + global_valid_seqs: Global valid sequence count for loss normalization + global_valid_toks: Global valid token count for loss normalization + straggler_timer: Straggler detector for profiling the forward pass + + Returns: + tuple: (output_tensor, post_processing_fn_wrapped) + - output_tensor: Raw model outputs (logits) + - post_processing_fn_wrapped: Function to create output post-processing function when called + """ + # Get the pre-processed microbatch from the iterator + processed_mb = next(data_iterator) + + # Extract the processed components + data_dict = processed_mb.data_dict + input_ids = processed_mb.input_ids + input_ids_cp_sharded = processed_mb.input_ids_cp_sharded + attention_mask = processed_mb.attention_mask + position_ids = processed_mb.position_ids + packed_seq_params = processed_mb.packed_seq_params + cu_seqlens_padded = processed_mb.cu_seqlens_padded + + output_tensor = model_forward( + model=model, + data_dict=data_dict, + cfg=cfg, + input_ids_cp_sharded=input_ids_cp_sharded, + position_ids=position_ids, + attention_mask=attention_mask, + packed_seq_params=packed_seq_params, + defer_fp32_logits=defer_fp32_logits, + straggler_timer=straggler_timer, + ) + + # Apply temperature scaling only for sampling-oriented post-processors. + # Loss computation should use unscaled logits. 
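# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this change): the numerical effect of the
# temperature division used above. A temperature above 1.0 flattens the softmax
# distribution, below 1.0 sharpens it. The config dict here is a toy stand-in
# shaped like the "generation" section of PolicyConfig.
import torch

toy_cfg = {"generation": {"temperature": 2.0}}
toy_logits = torch.tensor([[2.0, 0.0, -2.0]])
if "generation" in toy_cfg and toy_cfg["generation"] is not None:
    toy_logits.div_(toy_cfg["generation"]["temperature"])
assert torch.allclose(toy_logits, torch.tensor([[1.0, 0.0, -1.0]]))
# ---------------------------------------------------------------------------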
+ if isinstance( + post_processing_fn, + (LossPostProcessor, LogprobsPostProcessor, TopkLogitsPostProcessor), + ): + apply_temperature_scaling(output_tensor, cfg) + + # Use type checking to dispatch to the correct post-processing method + if isinstance(post_processing_fn, LossPostProcessor): + post_processing_fn_wrapped = post_processing_fn( + data_dict=data_dict, + packed_seq_params=packed_seq_params, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + ) + elif isinstance(post_processing_fn, LogprobsPostProcessor): + post_processing_fn_wrapped = post_processing_fn( + data_dict=data_dict, + input_ids=input_ids, + cu_seqlens_padded=cu_seqlens_padded, + ) + elif isinstance(post_processing_fn, TopkLogitsPostProcessor): + post_processing_fn_wrapped = post_processing_fn( + data_dict=data_dict, + cu_seqlens_padded=cu_seqlens_padded, + ) + else: + raise TypeError( + f"Unknown post-processing function type: {type(post_processing_fn)}" + ) + + return output_tensor, post_processing_fn_wrapped + + +def megatron_forward_backward( + model: GPTModel, + cfg: PolicyConfig, + data_iterator: Iterator[ProcessedMicrobatch], + num_microbatches: int, + seq_length: int, + mbs: int, + post_processing_fn: PostProcessingFunction, + forward_only: bool = False, + defer_fp32_logits: Optional[bool] = False, + global_valid_seqs: Optional[torch.Tensor] = None, + global_valid_toks: Optional[torch.Tensor] = None, + do_not_average_loss: bool = False, + straggler_timer: Optional[StragglerDetector] = None, +) -> Any: + """Execute forward and backward passes using Megatron's utilities. + + This is the main training loop function that coordinates forward and backward + passes across multiple microbatches using Megatron's pipeline parallel + execution framework. 
+ + Args: + model: The model to train + cfg: Policy configuration dictionary + data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed) + num_microbatches: Number of microbatches to process + seq_length: Sequence length + mbs: Micro batch size + post_processing_fn: Post-processing function to post-process the logits + forward_only: If True, skip backward pass + defer_fp32_logits: Whether to skip the conversion of logits to fp32 + global_valid_seqs: Global valid sequence count for loss normalization + global_valid_toks: Global valid token count for loss normalization + do_not_average_loss: If True, do not average loss across microbatches + straggler_timer: Straggler detector for profiling the forward pass + + Returns: + Results from the forward/backward execution + """ + forward_step = partial( + forward_with_post_processing_fn, + cfg=cfg, + post_processing_fn=post_processing_fn, + defer_fp32_logits=defer_fp32_logits, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + straggler_timer=straggler_timer, + ) + forward_backward_func = get_forward_backward_func() + return forward_backward_func( + forward_step_func=forward_step, + data_iterator=data_iterator, + model=model, + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=mbs, + decoder_seq_length=seq_length, + forward_only=forward_only, + do_not_average_loss=do_not_average_loss, + ) + + +class LossPostProcessor: + def __init__( + self, + loss_fn: LossFunction, + cfg: PolicyConfig, + cp_normalize: bool = True, + ): + self.loss_fn = loss_fn + self.cfg = cfg + self.cp_normalize = cp_normalize + + def __call__( + self, + data_dict: BatchedDataDict[Any], + packed_seq_params: Optional[PackedSeqParams] = None, + global_valid_seqs: Optional[torch.Tensor] = None, + global_valid_toks: Optional[torch.Tensor] = None, + ) -> Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, Any]]]: + """Create a loss post-processing function for training. + + This function wraps a loss function with the necessary context and parameters + to compute loss and metrics from model outputs. It handles sequence packing + and context parallelism normalization. 
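# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this change): how the policy worker drives a
# training step with these pieces, condensed from the megatron_policy_worker.py
# changes later in this diff. All arguments (model, cfg, loss_fn, iterator, and
# batch/sequence sizes) are assumed to come from the worker's setup code.
from nemo_rl.models.megatron.train import (
    LossPostProcessor,
    megatron_forward_backward,
)


def run_train_step(model, cfg, loss_fn, data_iterator, num_microbatches,
                   padded_seq_length, micro_batch_size,
                   global_valid_seqs, global_valid_toks):
    # Each microbatch's loss callable is built by LossPostProcessor; Megatron's
    # pipeline schedule invokes it on the last stage's output tensor.
    return megatron_forward_backward(
        model=model,
        cfg=cfg,
        data_iterator=data_iterator,
        num_microbatches=num_microbatches,
        seq_length=padded_seq_length,
        mbs=micro_batch_size,
        post_processing_fn=LossPostProcessor(loss_fn=loss_fn, cfg=cfg),
        forward_only=False,
        global_valid_seqs=global_valid_seqs,
        global_valid_toks=global_valid_toks,
        do_not_average_loss=True,
    )
# ---------------------------------------------------------------------------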
+ + Args: + data_dict: Batched data dictionary for the current microbatch + packed_seq_params: Parameters for packed sequences (optional) + global_valid_seqs: Global valid sequence count for loss normalization + global_valid_toks: Global valid token count for loss normalization + + Returns: + Callable: Function that takes output tensor and returns (loss, metrics) tuple + """ + loss_fn = self.loss_fn + pack_sequences = self.cfg["sequence_packing"]["enabled"] + if pack_sequences and packed_seq_params is not None: + # remove padding + loss_fn = SequencePackingLossWrapper( + loss_fn=loss_fn, + cu_seqlens_q=packed_seq_params.cu_seqlens_q, + cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded, + ) + + loss_fn_wrapped = partial( + loss_fn, + data=data_dict, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + vocab_parallel_rank=get_tensor_model_parallel_rank(), + vocab_parallel_group=get_tensor_model_parallel_group(), + context_parallel_group=get_context_parallel_group(), + ) + + if self.cp_normalize: + cp_size = get_context_parallel_world_size() + orig_loss_fn_wrapped = loss_fn_wrapped + + def _div_by_cp_size(*args, **kwargs): + loss, metrics = orig_loss_fn_wrapped(*args, **kwargs) + return loss / cp_size, metrics + + loss_fn_wrapped = _div_by_cp_size + + return loss_fn_wrapped + + +class LogprobsPostProcessor: + def __init__(self, cfg: PolicyConfig): + self.cfg = cfg + + def __call__( + self, + data_dict: BatchedDataDict[Any], + input_ids: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + ) -> Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + """Create a post-processing function that computes token log probabilities. + + This function returns a processor that takes model logits and converts them + to token-level log probabilities, handling both packed and unpacked sequences. 
+ + Args: + data_dict: Batched data dictionary containing input sequences + input_ids: Processed input token IDs + cu_seqlens_padded: Cumulative sequence lengths for packed sequences + + Returns: + Callable: Function that takes output tensor and returns (dummy_loss, {"logprobs": token_logprobs}) + """ + unpacked_input_ids = data_dict["input_ids"] + original_seq_length = unpacked_input_ids.shape[1] + + def processor_fn_inner(output_tensor): + tp_grp = get_tensor_model_parallel_group() + tp_rank = get_tensor_model_parallel_rank() + logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) + if self.cfg["sequence_packing"]["enabled"]: + token_logprobs = from_parallel_logits_to_logprobs_packed_sequences( + output_tensor, + target=input_ids, + cu_seqlens_padded=cu_seqlens_padded, + unpacked_seqlen=original_seq_length, + vocab_start_index=tp_rank * output_tensor.shape[-1], + vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], + group=tp_grp, + inference_only=True, + cp_group=get_context_parallel_group(), + chunk_size=logprob_chunk_size, + ) + else: + token_logprobs = from_parallel_logits_to_logprobs( + output_tensor, + target=unpacked_input_ids, + vocab_start_index=tp_rank * output_tensor.shape[-1], + vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], + tp_group=tp_grp, + inference_only=True, + chunk_size=logprob_chunk_size, + ) + + # Prepend 0 logprob for first token to maintain same sequence length as input + token_logprobs = torch.cat( + [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 + ) + return torch.tensor(0.0, device=token_logprobs.device), { + "logprobs": token_logprobs + } + + return processor_fn_inner + + +class TopkLogitsPostProcessor: + def __init__(self, cfg: PolicyConfig, k: int): + self.cfg = cfg + self.k = k + + def __call__( + self, + data_dict: BatchedDataDict[Any], + cu_seqlens_padded: torch.Tensor, + ) -> Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + """Create a post-processing function that computes top-k logits and indices. + + This function returns a processor that extracts the top-k highest logits + and their corresponding vocabulary indices from model outputs. It handles + tensor parallelism, context parallelism, and sequence packing. 
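# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this change): the shape convention used by
# the logprobs post-processor above. Teacher-forced logprobs only exist for
# tokens 1..T-1 of a length-T sequence, so a zero is prepended per sequence to
# keep the output aligned with input_ids. Toy tensors for illustration only.
import torch

token_logprobs = torch.tensor([[-0.5, -1.2, -0.3]])  # logprobs for tokens 1..3
aligned = torch.cat(
    [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1
)
assert aligned.shape == (1, 4)  # matches an input_ids sequence of length 4
# ---------------------------------------------------------------------------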
+ + Args: + data_dict: Batched data dictionary + cu_seqlens_padded: Cumulative sequence lengths for packed sequences + + Returns: + Callable: Function that takes output tensor and returns + (dummy_loss, {"topk_logits": values, "topk_indices": indices}) + """ + pack = self.cfg["sequence_packing"]["enabled"] + cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] + unpacked_seqlen = data_dict["input_ids"].shape[1] + seq_lengths = data_dict["input_lengths"] + + def processor_fn_inner(output_tensor): + tp_grp = get_tensor_model_parallel_group() + tp_rank = get_tensor_model_parallel_rank() + vocab_shard_size = output_tensor.shape[-1] + vocab_start_index = tp_rank * vocab_shard_size + + chunk_size = None + if "logprob_chunk_size" in self.cfg: + chunk_size = self.cfg["logprob_chunk_size"] + + topk_vals_local, topk_idx_local = distributed_vocab_topk( + output_tensor, + self.k, + tp_grp, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_start_index + vocab_shard_size, + chunk_size=chunk_size, + ) + + if self.cfg["megatron_cfg"]["context_parallel_size"] > 1: + cp_grp = get_context_parallel_group() + if pack: + # Per-sequence CP allgather following packed-sequence logic + batch_size = data_dict["input_ids"].shape[0] + total_packed_len = int(cu_seqlens_padded[-1].item()) + + topk_vals_full = torch.zeros( + (1, total_packed_len, self.k), + dtype=topk_vals_local.dtype, + device=topk_vals_local.device, + ) + topk_idx_full = torch.zeros( + (1, total_packed_len, self.k), + dtype=topk_idx_local.dtype, + device=topk_idx_local.device, + ) + + for i in range(batch_size): + start_idx = int(cu_seqlens_padded[i].item()) + end_idx = int(cu_seqlens_padded[i + 1].item()) + if end_idx > start_idx: + local_vals_slice = topk_vals_local[ + :, start_idx // cp_size : end_idx // cp_size, : + ] + local_idx_slice = topk_idx_local[ + :, start_idx // cp_size : end_idx // cp_size, : + ] + gathered_vals = allgather_cp_sharded_tensor( + local_vals_slice, cp_grp, seq_dim=1 + ) + gathered_idx = allgather_cp_sharded_tensor( + local_idx_slice, cp_grp, seq_dim=1 + ) + # Some kernels may return [X, Y, k] where X*Y = (end_idx - start_idx). + # Flatten leading dims and reshape to [1, expected_len, k] to match target. + expected_len = end_idx - start_idx + if ( + gathered_vals.dim() == 3 + and gathered_vals.shape[1] != expected_len + ): + gathered_vals = gathered_vals.reshape( + 1, expected_len, gathered_vals.shape[-1] + ) + if ( + gathered_idx.dim() == 3 + and gathered_idx.shape[1] != expected_len + ): + gathered_idx = gathered_idx.reshape( + 1, expected_len, gathered_idx.shape[-1] + ) + topk_vals_full[:, start_idx:end_idx, :] = gathered_vals + topk_idx_full[:, start_idx:end_idx, :] = gathered_idx + else: + # Sequence packing must be enabled when CP > 1 + raise RuntimeError( + "Context Parallelism (CP>1) requires sequence packing to be enabled." 
+ ) + else: + topk_vals_full = topk_vals_local + topk_idx_full = topk_idx_local + + if pack: + batch_size = data_dict["input_ids"].shape[0] + out_vals = torch.zeros( + (batch_size, unpacked_seqlen, self.k), + dtype=topk_vals_full.dtype, + device=topk_vals_full.device, + ) + out_idx = torch.zeros( + (batch_size, unpacked_seqlen, self.k), + dtype=topk_idx_full.dtype, + device=topk_idx_full.device, + ) + for i in range(batch_size): + seq_len = int(seq_lengths[i].item()) + start_idx = int(cu_seqlens_padded[i].item()) + if seq_len > 0: + out_vals[i, :seq_len, :] = topk_vals_full[ + 0, start_idx : start_idx + seq_len, : + ] + out_idx[i, :seq_len, :] = topk_idx_full[ + 0, start_idx : start_idx + seq_len, : + ] + return output_tensor.new_zeros(()), { + "topk_logits": out_vals, + "topk_indices": out_idx, + } + else: + return output_tensor.new_zeros(()), { + "topk_logits": topk_vals_full, + "topk_indices": topk_idx_full, + } + + return processor_fn_inner + + +def aggregate_training_statistics( + all_mb_metrics: List[Dict[str, Any]], + losses: List[float], + data_parallel_group: torch.distributed.ProcessGroup, +) -> Tuple[Dict[str, List[Any]], torch.Tensor]: + """Aggregate training statistics across microbatches and data-parallel ranks. + + Computes a global loss by all-reducing per-gradient-buffer losses across the + data-parallel group, then collects per-microbatch metrics into lists keyed by + metric name. + + Args: + all_mb_metrics: List of metric dicts from each microbatch. + losses: List of per-gradient-buffer scalar losses on this rank. + data_parallel_group: The data-parallel process group for all-reduce. + + Returns: + Tuple of: + - mb_metrics: Dict mapping metric names to lists of values across microbatches. + - global_loss: Tensor of losses summed across all data-parallel ranks. 
+ """ + # Compute global loss across all data-parallel ranks + with torch.no_grad(): + global_loss = torch.tensor(losses, device="cuda") + torch.distributed.all_reduce( + global_loss, + op=torch.distributed.ReduceOp.SUM, + group=data_parallel_group, + ) + + # Aggregate metrics across all microbatches + mb_metrics: Dict[str, List[Any]] = defaultdict(list) + for m in all_mb_metrics: + for k, v in m.items(): + mb_metrics[k].append(v) + + return dict(mb_metrics), global_loss diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index e1fcc27e65..d9a1c3d8a3 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -14,11 +14,9 @@ import gc import os import re -import time import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext -from functools import partial from typing import Any, Iterator, Optional, TypeVar, cast import ray @@ -44,29 +42,16 @@ from megatron.core.inference.text_generation_controllers.text_generation_controller import ( TextGenerationController, ) -from megatron.core.models.gpt import GPTModel from megatron.core.optimizer import ChainedOptimizer from megatron.core.parallel_state import ( - get_context_parallel_group, get_pipeline_model_parallel_group, - get_pipeline_model_parallel_last_rank, - get_pipeline_model_parallel_world_size, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, is_pipeline_last_stage, ) -from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import get_rerun_state_machine from transformers import PreTrainedTokenizerBase from nemo_rl.algorithms.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import ( - allgather_cp_sharded_tensor, - distributed_vocab_topk, - from_parallel_logits_to_logprobs, - from_parallel_logits_to_logprobs_packed_sequences, -) from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, @@ -74,16 +59,17 @@ verify_right_padding, ) from nemo_rl.models.generation.vllm.config import VllmConfig -from nemo_rl.models.megatron.common import ( - broadcast_tensor, - forward_step_arbitrary_loss, - get_moe_metrics, -) +from nemo_rl.models.megatron.common import get_moe_metrics from nemo_rl.models.megatron.config import MegatronGenerationConfig from nemo_rl.models.megatron.data import ( get_microbatch_iterator, process_global_batch, ) +from nemo_rl.models.megatron.pipeline_parallel import ( + broadcast_loss_metrics_from_last_stage, + broadcast_obj_from_pp_rank, + broadcast_tensors_from_last_stage, +) from nemo_rl.models.megatron.setup import ( finalize_megatron_setup, handle_model_import, @@ -93,6 +79,13 @@ validate_and_set_config, validate_model_paths, ) +from nemo_rl.models.megatron.train import ( + LogprobsPostProcessor, + LossPostProcessor, + TopkLogitsPostProcessor, + aggregate_training_statistics, + megatron_forward_backward, +) from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( ColocatablePolicyInterface, @@ -107,59 +100,6 @@ TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) -def broadcast_object_across_pp_ranks(obj): - """Broadcast an object across pipeline parallel ranks. 
- - This utility function handles broadcasting an object from the rank that owns it - to all other pipeline parallel ranks. If only one rank has the object (non-None), - it will be broadcast to all other ranks. - - Args: - obj: The object to broadcast. Can be None on ranks that don't own it. - - Returns: - The object on all ranks (either the original or the broadcast copy). - - Raises: - ValueError: If the object doesn't exist on any pipeline parallel rank. - """ - pp_size = get_pipeline_model_parallel_world_size() - pp_group = get_pipeline_model_parallel_group() - - if pp_size == 1: - return obj - - # ------------------------------------------------------------------ - # 1. Gather presence flags from all PP ranks to find the source rank - # ------------------------------------------------------------------ - has_obj = obj is not None - obj_flags = [None] * pp_size - torch.distributed.all_gather_object(obj_flags, has_obj, group=pp_group) - - # ------------------------------------------------------------------ - # 2. Identify the owning rank (the only rank with True flag) - # ------------------------------------------------------------------ - src_rank = None # Rank *inside* the PP group - for rank, flag in enumerate(obj_flags): - if flag: - src_rank = rank - break - - if src_rank is None: - raise ValueError("Object must exist on at least one PP rank") - - # ------------------------------------------------------------------ - # 3. Broadcast the object from the source rank to all ranks - # ------------------------------------------------------------------ - # Use broadcast_object_list which is more robust than all_gather_object - obj_list = [obj] - pp_ranks = torch.distributed.get_process_group_ranks(pp_group) - global_src = pp_ranks[src_rank] - torch.distributed.broadcast_object_list(obj_list, src=global_src, group=pp_group) - - return obj_list[0] - - @ray.remote( runtime_env=get_runtime_env_for_policy_worker("megatron_policy_worker") ) # pragma: no cover @@ -311,7 +251,8 @@ def train( mbs: Optional[int] = None, ) -> dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" - self.model.zero_grad_buffer() + # Note: zero_grad_buffer is called at the start of each global batch iteration + # in the loop below, so we don't need to call it here. if hasattr(self.model, "inference_params"): self.model.inference_params = None @@ -344,9 +285,6 @@ def train( self.model.train() with ctx: - forward_step = partial( - forward_step_arbitrary_loss, loss_fn=loss_fn, policy_cfg=self.cfg - ) all_mb_metrics = [] losses = [] total_num_microbatches = 0 @@ -377,6 +315,11 @@ def train( # Track total microbatches for MoE aux-loss averaging total_num_microbatches += int(num_microbatches) + loss_post_processor = LossPostProcessor( + loss_fn=loss_fn, + cfg=self.cfg, + ) + rerun_state_machine = get_rerun_state_machine() while rerun_state_machine.should_run_forward_backward(data_iterator): # Set grad to zero. @@ -384,24 +327,20 @@ def train( self.optimizer.zero_grad() # Forward pass. 
- forward_backward_func = get_forward_backward_func() - losses_reduced = forward_backward_func( - forward_step_func=partial( - forward_step, - self.mcore_state, - global_valid_seqs, - global_valid_toks, - pack_sequences=self.cfg["sequence_packing"]["enabled"], - defer_fp32_logits=self.defer_fp32_logits, - ), - data_iterator=data_iterator, + losses_reduced = megatron_forward_backward( model=self.model, + cfg=self.cfg, + data_iterator=data_iterator, num_microbatches=num_microbatches, seq_length=padded_seq_length, - micro_batch_size=mbs, - decoder_seq_length=padded_seq_length, + mbs=micro_batch_size, + post_processing_fn=loss_post_processor, forward_only=eval_mode, + defer_fp32_logits=self.defer_fp32_logits, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, do_not_average_loss=True, + straggler_timer=self.mcore_state.straggler_timer, ) # Empty unused memory. @@ -461,19 +400,14 @@ def train( loss_metrics["global_valid_toks"] = global_valid_toks.item() mb_losses.append(loss_metrics["loss"]) - torch.distributed.broadcast_object_list( - [gb_loss_metrics], - src=get_pipeline_model_parallel_last_rank(), - group=get_pipeline_model_parallel_group(), - ) else: - loss_metrics = [None] # type: ignore - torch.distributed.broadcast_object_list( - loss_metrics, - src=get_pipeline_model_parallel_last_rank(), - group=get_pipeline_model_parallel_group(), - ) - gb_loss_metrics = loss_metrics[0] + gb_loss_metrics = None + + # Broadcast loss metrics from last stage to all stages + gb_loss_metrics = broadcast_loss_metrics_from_last_stage( + gb_loss_metrics + ) + if not parallel_state.is_pipeline_last_stage(ignore_virtual=True): mb_losses = [x["loss"] for x in gb_loss_metrics] all_mb_metrics.extend(gb_loss_metrics) @@ -486,25 +420,18 @@ def train( self.scheduler.step(increment=gbs) # Aggregate metrics across all microbatches - mb_metrics = defaultdict(list) - for m in all_mb_metrics: - for k, v in m.items(): - mb_metrics[k].append(v) - - with torch.no_grad(): - global_loss = torch.tensor(losses, device="cuda") - torch.distributed.all_reduce( - global_loss, - op=torch.distributed.ReduceOp.SUM, - group=parallel_state.get_data_parallel_group(), - ) + mb_metrics, global_loss = aggregate_training_statistics( + all_mb_metrics=all_mb_metrics, + losses=losses, + data_parallel_group=parallel_state.get_data_parallel_group(), + ) metrics = { "global_loss": global_loss.cpu(), "rank": torch.distributed.get_rank(), "gpu_name": torch.cuda.get_device_name(), "model_dtype": self.dtype, - "all_mb_metrics": dict(mb_metrics), + "all_mb_metrics": mb_metrics, "grad_norm": torch.tensor([grad_norm]), } # Collect MoE aux metrics averaged across microbatches @@ -561,97 +488,19 @@ def get_logprobs( straggler_timer=self.mcore_state.straggler_timer, ) - def forward_step_fn( - data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel - ): - processed_mb = next(data_iterator) - # Extract the processed components - data_dict = processed_mb.data_dict - input_ids = processed_mb.input_ids - input_ids_cp_sharded = processed_mb.input_ids_cp_sharded - attention_mask = processed_mb.attention_mask - position_ids = processed_mb.position_ids - packed_seq_params = processed_mb.packed_seq_params - cu_seqlens_padded = processed_mb.cu_seqlens_padded - unpacked_input_ids = data_dict["input_ids"] - - multimodal_data = data_dict.get_multimodal_dict( - as_tensors=True, device=input_ids.device - ) - if len(multimodal_data) > 0: - position_ids = None - - additional_kwargs = {} - # Mamba models currently do not support 
packed_seq_params - if packed_seq_params is not None: - additional_kwargs["packed_seq_params"] = packed_seq_params - - if self.defer_fp32_logits: - additional_kwargs["fp32_output"] = False - - output_tensor = model( - input_ids=input_ids_cp_sharded, - position_ids=position_ids, - attention_mask=attention_mask, - **multimodal_data, - **additional_kwargs, - ) - - # Apply temperature scaling to logits for training - # This matches the dtensor worker's _apply_temperature_scaling in the train method - if "generation" in self.cfg and self.cfg["generation"] is not None: - output_tensor.div_(self.cfg["generation"]["temperature"]) - - def collection_fn(output_tensor): - stc = time.time() - tp_grp = get_tensor_model_parallel_group() - tp_rank = get_tensor_model_parallel_rank() - logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) - if self.cfg["sequence_packing"]["enabled"]: - token_logprobs = from_parallel_logits_to_logprobs_packed_sequences( - output_tensor, - target=input_ids, - cu_seqlens_padded=cu_seqlens_padded, - unpacked_seqlen=seq_length, - vocab_start_index=tp_rank * output_tensor.shape[-1], - vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], - group=tp_grp, - inference_only=True, - cp_group=get_context_parallel_group(), - chunk_size=logprob_chunk_size, - ) - else: - token_logprobs = from_parallel_logits_to_logprobs( - output_tensor, - target=unpacked_input_ids, - vocab_start_index=tp_rank * output_tensor.shape[-1], - vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], - tp_group=tp_grp, - inference_only=True, - chunk_size=logprob_chunk_size, - ) - - # Prepend 0 logprob for first token to maintain same sequence length as input - token_logprobs = torch.cat( - [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 - ) - return torch.tensor(0.0, device=token_logprobs.device), { - "logprobs": token_logprobs - } - - return output_tensor, collection_fn - - forward_backward_func = get_forward_backward_func() - list_of_logprobs = forward_backward_func( - forward_step_func=forward_step_fn, - data_iterator=mb_iterator, + list_of_logprobs = megatron_forward_backward( model=self.model, - num_microbatches=num_microbatches, + cfg=self.cfg, + data_iterator=mb_iterator, seq_length=padded_seq_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=padded_seq_length, + mbs=micro_batch_size, + num_microbatches=num_microbatches, + post_processing_fn=LogprobsPostProcessor(cfg=self.cfg), forward_only=True, + defer_fp32_logits=self.defer_fp32_logits, + straggler_timer=self.mcore_state.straggler_timer, ) + if is_pipeline_last_stage(ignore_virtual=True): all_log_probs_padded = [] all_logprobs = [l["logprobs"] for l in list_of_logprobs] @@ -664,12 +513,10 @@ def collection_fn(output_tensor): all_log_probs_padded.append(lp) logprobs = torch.cat(all_log_probs_padded, dim=0) - # broadcast logprobs to first pp rank - broadcast_tensor(logprobs, torch.distributed.get_rank(), pp_grp) + tensors = {"logprobs": logprobs} else: - logprobs = broadcast_tensor( - None, get_pipeline_model_parallel_last_rank(), pp_grp - ) + tensors = {"logprobs": None} + logprobs = broadcast_tensors_from_last_stage(tensors)["logprobs"] no_grad.__exit__(None, None, None) return BatchedDataDict[LogprobOutputSpec](logprobs=logprobs).to("cpu") @@ -768,172 +615,17 @@ def get_topk_logits( straggler_timer=self.mcore_state.straggler_timer, ) - def forward_step_fn( - data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel - ): - processed_mb = next(data_iterator) - # Extract the processed components - 
data_dict = processed_mb.data_dict - input_ids = processed_mb.input_ids - input_ids_cp_sharded = processed_mb.input_ids_cp_sharded - attention_mask = processed_mb.attention_mask - position_ids = processed_mb.position_ids - packed_seq_params = processed_mb.packed_seq_params - cu_seqlens_padded = processed_mb.cu_seqlens_padded - unpacked_input_ids = data_dict["input_ids"] - - multimodal_data = data_dict.get_multimodal_dict( - as_tensors=True, device=input_ids_cp_sharded.device - ) - if len(multimodal_data) > 0: - position_ids = None - - additional_kwargs = {} - if packed_seq_params is not None: - additional_kwargs["packed_seq_params"] = packed_seq_params - - output_tensor = model( - input_ids=input_ids_cp_sharded, - position_ids=position_ids, - attention_mask=attention_mask, - **additional_kwargs, - **multimodal_data, - ) - - if "generation" in self.cfg and self.cfg["generation"] is not None: - output_tensor.div_(self.cfg["generation"]["temperature"]) - - def collection_fn(_): - # Only the last PP stage produces final logits/top-k; earlier stages return empty - # if not is_pipeline_last_stage(ignore_virtual=True): - # return output_tensor.new_zeros(()), {} - - tp_grp = get_tensor_model_parallel_group() - tp_rank = get_tensor_model_parallel_rank() - vocab_shard_size = output_tensor.shape[-1] - vocab_start_index = tp_rank * vocab_shard_size - - chunk_size = None - if "logprob_chunk_size" in self.cfg: - chunk_size = self.cfg["logprob_chunk_size"] - - topk_vals_local, topk_idx_local = distributed_vocab_topk( - output_tensor, - k, - tp_grp, - vocab_start_index=vocab_start_index, - vocab_end_index=vocab_start_index + vocab_shard_size, - chunk_size=chunk_size, - ) - - if self.cfg["megatron_cfg"]["context_parallel_size"] > 1: - cp_grp = get_context_parallel_group() - if self.cfg["sequence_packing"]["enabled"]: - cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] - # Per-sequence CP allgather following packed-sequence logic - batch_size = data_dict["input_ids"].shape[0] - total_packed_len = int(cu_seqlens_padded[-1].item()) - - topk_vals_full = torch.zeros( - (1, total_packed_len, k), - dtype=topk_vals_local.dtype, - device=topk_vals_local.device, - ) - topk_idx_full = torch.zeros( - (1, total_packed_len, k), - dtype=topk_idx_local.dtype, - device=topk_idx_local.device, - ) - - for i in range(batch_size): - start_idx = int(cu_seqlens_padded[i].item()) - end_idx = int(cu_seqlens_padded[i + 1].item()) - if end_idx > start_idx: - local_vals_slice = topk_vals_local[ - :, start_idx // cp_size : end_idx // cp_size, : - ] - local_idx_slice = topk_idx_local[ - :, start_idx // cp_size : end_idx // cp_size, : - ] - gathered_vals = allgather_cp_sharded_tensor( - local_vals_slice, cp_grp, seq_dim=1 - ) - gathered_idx = allgather_cp_sharded_tensor( - local_idx_slice, cp_grp, seq_dim=1 - ) - # Some kernels may return [X, Y, k] where X*Y = (end_idx - start_idx). - # Flatten leading dims and reshape to [1, expected_len, k] to match target. 
- expected_len = end_idx - start_idx - if ( - gathered_vals.dim() == 3 - and gathered_vals.shape[1] != expected_len - ): - gathered_vals = gathered_vals.reshape( - 1, expected_len, gathered_vals.shape[-1] - ) - if ( - gathered_idx.dim() == 3 - and gathered_idx.shape[1] != expected_len - ): - gathered_idx = gathered_idx.reshape( - 1, expected_len, gathered_idx.shape[-1] - ) - topk_vals_full[:, start_idx:end_idx, :] = gathered_vals - topk_idx_full[:, start_idx:end_idx, :] = gathered_idx - else: - # Sequence packing must be enabled when CP > 1 - raise RuntimeError( - "Context Parallelism (CP>1) requires sequence packing to be enabled." - ) - else: - topk_vals_full = topk_vals_local - topk_idx_full = topk_idx_local - - if self.cfg["sequence_packing"]["enabled"]: - batch_size = data_dict["input_ids"].shape[0] - seq_lengths = data_dict["input_lengths"] - out_vals = torch.zeros( - (batch_size, seq_length, k), - dtype=topk_vals_full.dtype, - device=topk_vals_full.device, - ) - out_idx = torch.zeros( - (batch_size, seq_length, k), - dtype=topk_idx_full.dtype, - device=topk_idx_full.device, - ) - for i in range(batch_size): - seq_len = int(seq_lengths[i].item()) - start_idx = int(cu_seqlens_padded[i].item()) - if seq_len > 0: - out_vals[i, :seq_len, :] = topk_vals_full[ - 0, start_idx : start_idx + seq_len, : - ] - out_idx[i, :seq_len, :] = topk_idx_full[ - 0, start_idx : start_idx + seq_len, : - ] - return output_tensor.new_zeros(()), { - "topk_logits": out_vals, - "topk_indices": out_idx, - } - else: - return output_tensor.new_zeros(()), { - "topk_logits": topk_vals_full, - "topk_indices": topk_idx_full, - } - - return output_tensor, collection_fn - - forward_backward_func = get_forward_backward_func() - list_of_outputs = forward_backward_func( - forward_step_func=forward_step_fn, - data_iterator=mb_iterator, + list_of_outputs = megatron_forward_backward( model=self.model, - num_microbatches=num_microbatches, + cfg=self.cfg, + data_iterator=mb_iterator, seq_length=padded_seq_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=padded_seq_length, + mbs=micro_batch_size, + num_microbatches=num_microbatches, + post_processing_fn=TopkLogitsPostProcessor(cfg=self.cfg, k=k), forward_only=True, + defer_fp32_logits=self.defer_fp32_logits, + straggler_timer=self.mcore_state.straggler_timer, ) if is_pipeline_last_stage(ignore_virtual=True): @@ -952,16 +644,20 @@ def collection_fn(_): topk_logits = torch.cat(logits_chunks, dim=0) topk_indices = torch.cat(indices_chunks, dim=0) - topk_logits = broadcast_tensor( - topk_logits, torch.distributed.get_rank(), pp_grp - ) - topk_indices = broadcast_tensor( - topk_indices, torch.distributed.get_rank(), pp_grp - ) + tensors_to_broadcast = { + "topk_logits": topk_logits, + "topk_indices": topk_indices, + } else: - last_pp_rank = get_pipeline_model_parallel_last_rank() - topk_logits = broadcast_tensor(None, last_pp_rank, pp_grp) - topk_indices = broadcast_tensor(None, last_pp_rank, pp_grp) + tensors_to_broadcast = { + "topk_logits": None, + "topk_indices": None, + } + + # Broadcast tensors from last stage to all stages + broadcasted = broadcast_tensors_from_last_stage(tensors_to_broadcast) + topk_logits = broadcasted["topk_logits"] + topk_indices = broadcasted["topk_indices"] no_grad.__exit__(None, None, None) return BatchedDataDict.from_batches( @@ -1254,7 +950,7 @@ def calculate_size_in_bytes(param, tp_size, ep_size): ) # Broadcast size_in_bytes across pipeline parallel ranks - return broadcast_object_across_pp_ranks(size_in_bytes) + return 
broadcast_obj_from_pp_rank(size_in_bytes) for task in self.refit_conversion_tasks: param_info.append( diff --git a/tests/unit/algorithms/test_sequence_packing_gradients.py b/tests/unit/algorithms/test_sequence_packing_gradients.py index 2982d1b0a8..f0ce832eb0 100644 --- a/tests/unit/algorithms/test_sequence_packing_gradients.py +++ b/tests/unit/algorithms/test_sequence_packing_gradients.py @@ -14,6 +14,7 @@ """Test script to debug high gradients with sequence packing + context parallelism.""" import os +from unittest.mock import MagicMock import pytest import ray @@ -41,13 +42,14 @@ def __init__(self, cp_size): def test_sequence_packing_gradients(self): from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank - from nemo_rl.models.megatron.common import ( - forward_step_arbitrary_loss, - ) from nemo_rl.models.megatron.data import ( _pack_sequences_for_megatron, make_processed_microbatch_iterator, ) + from nemo_rl.models.megatron.train import ( + LossPostProcessor, + forward_with_post_processing_fn, + ) # Initialize process group torch.distributed.init_process_group(backend="nccl") @@ -289,11 +291,18 @@ def make_packed_logits(logits): packed_grad, baseline_grad_store, atol=1e-5, rtol=1e-5 ) - # test 3: with forward_step_arbitrary_loss + # test 3: with forward_with_post_processing_fn # reset grad baseline_logits.grad.zero_() packed_logits = make_packed_logits(baseline_logits) + # mock straggler detector with dummy context manager + mock_straggler_timer = MagicMock() + mock_straggler_timer.return_value = MagicMock( + __enter__=MagicMock(return_value=None), + __exit__=MagicMock(return_value=False), + ) + # mock model forward class MockModel: def __init__(self): @@ -307,51 +316,39 @@ def forward( ): return self.logits - class MockMcoreState: - def __init__(self): - # context that does nothing, but supports both with straggler_timer and with straggler_timer(bdata=True) - from contextlib import nullcontext - - class DummyStragglerTimer: - def __call__(self, *args, **kwargs): - return nullcontext() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - self.straggler_timer = DummyStragglerTimer() + cfg = { + "sequence_packing": {"enabled": True}, + "dynamic_batching": {"enabled": False}, + "megatron_cfg": { + "tensor_model_parallel_size": 1, + "sequence_parallel": False, + "pipeline_model_parallel_size": 1, + "context_parallel_size": cp_size, + }, + } - mock_mcore_state = MockMcoreState() + post_processor = LossPostProcessor( + loss_fn=base_loss_fn, + cfg=cfg, + cp_normalize=True, + ) - output_tensor, wrapped_loss_fn = forward_step_arbitrary_loss( - mock_mcore_state, - global_valid_seqs, - global_valid_toks, + output_tensor, wrapped_loss_fn = forward_with_post_processing_fn( data_iterator=make_processed_microbatch_iterator( iter([packed_data_dict]), - cfg={ - "sequence_packing": {"enabled": True}, - "dynamic_batching": {"enabled": False}, - "megatron_cfg": { - "tensor_model_parallel_size": 1, - "sequence_parallel": False, - "pipeline_model_parallel_size": 1, - "context_parallel_size": cp_size, - }, - }, + cfg=cfg, seq_length_key="input_lengths", pad_individual_seqs_to_multiple_of=pad_to_multiple, pad_packed_seq_to_multiple_of=1, - straggler_timer=mock_mcore_state.straggler_timer, + straggler_timer=mock_straggler_timer, pad_full_seq_to=max_seq_len * batch_size if cp_size > 1 else None, ), model=MockModel(), - loss_fn=base_loss_fn, - pack_sequences=True, - cp_normalize=True, + cfg=cfg, + post_processing_fn=post_processor, + 
global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + straggler_timer=mock_straggler_timer, ) loss, metrics = wrapped_loss_fn(output_tensor) diff --git a/tests/unit/models/megatron/test_train.py b/tests/unit/models/megatron/test_train.py new file mode 100644 index 0000000000..cf261c3d75 --- /dev/null +++ b/tests/unit/models/megatron/test_train.py @@ -0,0 +1,1242 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for Megatron training utilities. + +This module tests the training functions in nemo_rl.models.megatron.train, +focusing on: +- Model forward pass +- Forward with post-processing +- Loss/logprobs/topk post-processors +""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + + +class TestModelForward: + """Tests for model_forward function.""" + + def test_model_forward_basic(self): + """Test basic model_forward without multimodal data.""" + from nemo_rl.models.megatron.train import model_forward + + # Setup mocks + mock_model = MagicMock() + mock_output = torch.randn(2, 10, 100) + mock_model.return_value = mock_output + + mock_data_dict = MagicMock() + mock_data_dict.get_multimodal_dict.return_value = {} + + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) + position_ids = torch.tensor([[0, 1, 2], [0, 1, 2]]) + attention_mask = torch.ones(2, 3) + cfg = {} + + result = model_forward( + model=mock_model, + data_dict=mock_data_dict, + cfg=cfg, + input_ids_cp_sharded=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + ) + + assert torch.equal(result, mock_output) + mock_model.assert_called_once() + + def test_model_forward_with_straggler_timer(self): + """Test model_forward uses straggler_timer context manager when provided.""" + from nemo_rl.models.megatron.train import model_forward + + mock_model = MagicMock() + mock_output = torch.randn(1, 10, 100) + mock_model.return_value = mock_output + + mock_data_dict = MagicMock() + mock_data_dict.get_multimodal_dict.return_value = {} + + mock_timer = MagicMock() + mock_ctx = MagicMock() + mock_timer.return_value = mock_ctx + + result = model_forward( + model=mock_model, + data_dict=mock_data_dict, + cfg={}, + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + position_ids=torch.tensor([[0, 1, 2]]), + attention_mask=torch.ones(1, 3), + straggler_timer=mock_timer, + ) + + # Verify straggler_timer was called as a context manager + mock_timer.assert_called_once() + mock_ctx.__enter__.assert_called_once() + mock_ctx.__exit__.assert_called_once() + assert torch.equal(result, mock_output) + + def test_model_forward_with_packed_seq_params(self): + """Test model_forward passes packed_seq_params to model.""" + from nemo_rl.models.megatron.train import model_forward + + mock_model = MagicMock() + mock_model.return_value = torch.randn(1, 10, 100) + + mock_data_dict = MagicMock() + mock_data_dict.get_multimodal_dict.return_value = {} + + mock_packed_seq_params = MagicMock() + + model_forward( + 
model=mock_model, + data_dict=mock_data_dict, + cfg={}, + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + position_ids=torch.tensor([[0, 1, 2]]), + attention_mask=torch.ones(1, 3), + packed_seq_params=mock_packed_seq_params, + ) + + # Verify packed_seq_params was passed + call_kwargs = mock_model.call_args[1] + assert call_kwargs["packed_seq_params"] == mock_packed_seq_params + + def test_model_forward_with_defer_fp32_logits(self): + """Test model_forward passes fp32_output when defer_fp32_logits is True.""" + from nemo_rl.models.megatron.train import model_forward + + mock_model = MagicMock() + mock_model.return_value = torch.randn(1, 10, 100) + + mock_data_dict = MagicMock() + mock_data_dict.get_multimodal_dict.return_value = {} + + model_forward( + model=mock_model, + data_dict=mock_data_dict, + cfg={}, + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + position_ids=torch.tensor([[0, 1, 2]]), + attention_mask=torch.ones(1, 3), + defer_fp32_logits=True, + ) + + call_kwargs = mock_model.call_args[1] + assert call_kwargs["fp32_output"] is False + + def test_model_forward_clears_position_ids_for_multimodal(self): + """Test model_forward sets position_ids to None for multimodal data.""" + from nemo_rl.models.megatron.train import model_forward + + mock_model = MagicMock() + mock_model.return_value = torch.randn(1, 10, 100) + + mock_data_dict = MagicMock() + mock_data_dict.get_multimodal_dict.return_value = { + "images": torch.randn(1, 3, 224, 224) + } + + model_forward( + model=mock_model, + data_dict=mock_data_dict, + cfg={}, + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + position_ids=torch.tensor([[0, 1, 2]]), + attention_mask=torch.ones(1, 3), + ) + + call_kwargs = mock_model.call_args[1] + assert call_kwargs["position_ids"] is None + + +class TestApplyTemperatureScaling: + """Tests for apply_temperature_scaling function.""" + + def test_temperature_scaling_with_generation_config(self): + """Test that logits are divided by the configured temperature.""" + from nemo_rl.models.megatron.train import apply_temperature_scaling + + logits = torch.ones(2, 10, 100) * 4.0 + cfg = {"generation": {"temperature": 2.0}} + + result = apply_temperature_scaling(logits, cfg) + + # 4.0 / 2.0 = 2.0 + assert torch.allclose(result, torch.ones_like(result) * 2.0) + # Verify in-place: result is the same tensor + assert result.data_ptr() == logits.data_ptr() + + def test_temperature_scaling_no_generation_key(self): + """Test that logits are unchanged when 'generation' key is absent.""" + from nemo_rl.models.megatron.train import apply_temperature_scaling + + logits = torch.ones(2, 10, 100) * 3.0 + cfg = {} + + result = apply_temperature_scaling(logits, cfg) + + assert torch.allclose(result, torch.ones_like(result) * 3.0) + + def test_temperature_scaling_generation_is_none(self): + """Test that logits are unchanged when generation config is None.""" + from nemo_rl.models.megatron.train import apply_temperature_scaling + + logits = torch.ones(2, 10, 100) * 3.0 + cfg = {"generation": None} + + result = apply_temperature_scaling(logits, cfg) + + assert torch.allclose(result, torch.ones_like(result) * 3.0) + + def test_temperature_scaling_with_temperature_one(self): + """Test that temperature=1.0 leaves logits unchanged.""" + from nemo_rl.models.megatron.train import apply_temperature_scaling + + logits = torch.randn(2, 10, 100) + original = logits.clone() + cfg = {"generation": {"temperature": 1.0}} + + result = apply_temperature_scaling(logits, cfg) + + assert torch.allclose(result, original) + + +class 
TestForwardWithPostProcessingFn: + """Tests for forward_with_post_processing_fn function.""" + + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_context_parallel_world_size", return_value=1 + ) + @patch("nemo_rl.models.megatron.train.model_forward") + def test_forward_with_loss_post_processor( + self, mock_model_forward, mock_cp_size, mock_cp_grp, mock_tp_grp, mock_tp_rank + ): + """Test forward with LossPostProcessor.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import ( + LossPostProcessor, + forward_with_post_processing_fn, + ) + + # Set up mock return values for process groups + mock_tp_grp.return_value = MagicMock() + mock_cp_grp.return_value = MagicMock() + + mock_model_forward.return_value = torch.randn(2, 10, 100) + + # Create processed microbatch + processed_mb = ProcessedMicrobatch( + data_dict=MagicMock(), + input_ids=torch.tensor([[1, 2, 3]]), + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=torch.ones(1, 3), + position_ids=torch.tensor([[0, 1, 2]]), + packed_seq_params=None, + cu_seqlens_padded=None, + ) + + data_iterator = iter([processed_mb]) + mock_model = MagicMock() + cfg = {"sequence_packing": {"enabled": False}} + + mock_loss_fn = MagicMock() + post_processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg) + + output, wrapped_fn = forward_with_post_processing_fn( + data_iterator=data_iterator, + model=mock_model, + cfg=cfg, + post_processing_fn=post_processor, + ) + + mock_model_forward.assert_called_once() + + # forward_with_post_processing_fn should return a callable + assert callable(wrapped_fn) + assert isinstance(output, torch.Tensor) + + @patch("nemo_rl.models.megatron.train.model_forward") + def test_forward_with_logprobs_post_processor(self, mock_model_forward): + """Test forward with LogprobsPostProcessor.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import ( + LogprobsPostProcessor, + forward_with_post_processing_fn, + ) + + mock_model_forward.return_value = torch.randn(2, 10, 100) + + processed_mb = ProcessedMicrobatch( + data_dict=MagicMock(), + input_ids=torch.tensor([[1, 2, 3]]), + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=torch.ones(1, 3), + position_ids=torch.tensor([[0, 1, 2]]), + packed_seq_params=None, + cu_seqlens_padded=None, + ) + + data_iterator = iter([processed_mb]) + cfg = {"sequence_packing": {"enabled": False}} + post_processor = LogprobsPostProcessor(cfg=cfg) + + with patch.object(post_processor, "__call__", return_value=MagicMock()): + output, wrapped_fn = forward_with_post_processing_fn( + data_iterator=data_iterator, + model=MagicMock(), + cfg=cfg, + post_processing_fn=post_processor, + ) + + mock_model_forward.assert_called_once() + + @patch("nemo_rl.models.megatron.train.model_forward") + def test_forward_with_topk_post_processor(self, mock_model_forward): + """Test forward with TopkLogitsPostProcessor.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import ( + TopkLogitsPostProcessor, + forward_with_post_processing_fn, + ) + + mock_model_forward.return_value = torch.randn(2, 10, 100) + + processed_mb = ProcessedMicrobatch( + data_dict=MagicMock(), + input_ids=torch.tensor([[1, 2, 3]]), + 
input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=torch.ones(1, 3), + position_ids=torch.tensor([[0, 1, 2]]), + packed_seq_params=None, + cu_seqlens_padded=None, + ) + + data_iterator = iter([processed_mb]) + cfg = { + "sequence_packing": {"enabled": False}, + "megatron_cfg": {"context_parallel_size": 1}, + } + post_processor = TopkLogitsPostProcessor(cfg=cfg, k=5) + + with patch.object(post_processor, "__call__", return_value=MagicMock()): + output, wrapped_fn = forward_with_post_processing_fn( + data_iterator=data_iterator, + model=MagicMock(), + cfg=cfg, + post_processing_fn=post_processor, + ) + + mock_model_forward.assert_called_once() + + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_context_parallel_world_size", return_value=1 + ) + @patch("nemo_rl.models.megatron.train.model_forward") + @patch("nemo_rl.models.megatron.train.apply_temperature_scaling") + def test_forward_applies_temperature_scaling_for_loss( + self, + mock_temp_scaling, + mock_model_forward, + mock_cp_size, + mock_cp_grp, + mock_tp_grp, + mock_tp_rank, + ): + """Test that forward_with_post_processing_fn applies temperature scaling for LossPostProcessor.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import ( + LossPostProcessor, + forward_with_post_processing_fn, + ) + + mock_tp_grp.return_value = MagicMock() + mock_cp_grp.return_value = MagicMock() + + output_tensor = torch.randn(2, 10, 100) + mock_model_forward.return_value = output_tensor + + processed_mb = ProcessedMicrobatch( + data_dict=MagicMock(), + input_ids=torch.tensor([[1, 2, 3]]), + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=torch.ones(1, 3), + position_ids=torch.tensor([[0, 1, 2]]), + packed_seq_params=None, + cu_seqlens_padded=None, + ) + + cfg = { + "sequence_packing": {"enabled": False}, + "generation": {"temperature": 0.7}, + } + post_processor = LossPostProcessor(loss_fn=MagicMock(), cfg=cfg) + + forward_with_post_processing_fn( + data_iterator=iter([processed_mb]), + model=MagicMock(), + cfg=cfg, + post_processing_fn=post_processor, + ) + + # Verify apply_temperature_scaling was called with the output tensor and cfg + mock_temp_scaling.assert_called_once_with(output_tensor, cfg) + + @patch("nemo_rl.models.megatron.train.model_forward") + @patch("nemo_rl.models.megatron.train.apply_temperature_scaling") + def test_forward_applies_temperature_scaling_for_logprobs( + self, mock_temp_scaling, mock_model_forward + ): + """Test that forward_with_post_processing_fn applies temperature scaling for LogprobsPostProcessor.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import ( + LogprobsPostProcessor, + forward_with_post_processing_fn, + ) + + output_tensor = torch.randn(2, 10, 100) + mock_model_forward.return_value = output_tensor + + processed_mb = ProcessedMicrobatch( + data_dict=MagicMock(), + input_ids=torch.tensor([[1, 2, 3]]), + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=torch.ones(1, 3), + position_ids=torch.tensor([[0, 1, 2]]), + packed_seq_params=None, + cu_seqlens_padded=None, + ) + + cfg = { + "sequence_packing": {"enabled": False}, + "generation": {"temperature": 0.5}, + } + post_processor = LogprobsPostProcessor(cfg=cfg) + + with 
patch.object(post_processor, "__call__", return_value=MagicMock()): + forward_with_post_processing_fn( + data_iterator=iter([processed_mb]), + model=MagicMock(), + cfg=cfg, + post_processing_fn=post_processor, + ) + + mock_temp_scaling.assert_called_once_with(output_tensor, cfg) + + @patch("nemo_rl.models.megatron.train.model_forward") + @patch("nemo_rl.models.megatron.train.apply_temperature_scaling") + def test_forward_applies_temperature_scaling_for_topk( + self, mock_temp_scaling, mock_model_forward + ): + """Test that forward_with_post_processing_fn applies temperature scaling for TopkLogitsPostProcessor.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import ( + TopkLogitsPostProcessor, + forward_with_post_processing_fn, + ) + + output_tensor = torch.randn(2, 10, 100) + mock_model_forward.return_value = output_tensor + + processed_mb = ProcessedMicrobatch( + data_dict=MagicMock(), + input_ids=torch.tensor([[1, 2, 3]]), + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=torch.ones(1, 3), + position_ids=torch.tensor([[0, 1, 2]]), + packed_seq_params=None, + cu_seqlens_padded=None, + ) + + cfg = { + "sequence_packing": {"enabled": False}, + "megatron_cfg": {"context_parallel_size": 1}, + "generation": {"temperature": 1.5}, + } + post_processor = TopkLogitsPostProcessor(cfg=cfg, k=5) + + with patch.object(post_processor, "__call__", return_value=MagicMock()): + forward_with_post_processing_fn( + data_iterator=iter([processed_mb]), + model=MagicMock(), + cfg=cfg, + post_processing_fn=post_processor, + ) + + mock_temp_scaling.assert_called_once_with(output_tensor, cfg) + + @patch("nemo_rl.models.megatron.train.model_forward") + @patch("nemo_rl.models.megatron.train.apply_temperature_scaling") + def test_forward_does_not_apply_temperature_scaling_for_unknown_type( + self, mock_temp_scaling, mock_model_forward + ): + """Test that temperature scaling is NOT applied for unknown post-processor types (before they raise).""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import forward_with_post_processing_fn + + mock_model_forward.return_value = torch.randn(2, 10, 100) + + processed_mb = ProcessedMicrobatch( + data_dict=MagicMock(), + input_ids=torch.tensor([[1, 2, 3]]), + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=None, + position_ids=None, + packed_seq_params=None, + cu_seqlens_padded=None, + ) + + with pytest.raises(TypeError): + forward_with_post_processing_fn( + data_iterator=iter([processed_mb]), + model=MagicMock(), + cfg={"generation": {"temperature": 2.0}}, + post_processing_fn="not_a_processor", + ) + + mock_temp_scaling.assert_not_called() + + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_context_parallel_world_size", return_value=1 + ) + @patch("nemo_rl.models.megatron.train.model_forward") + def test_forward_with_straggler_timer( + self, mock_model_forward, mock_cp_size, mock_cp_grp, mock_tp_grp, mock_tp_rank + ): + """Test that straggler_timer is passed through to model_forward.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import ( + LossPostProcessor, + forward_with_post_processing_fn, + ) + + mock_tp_grp.return_value = MagicMock() + 
mock_cp_grp.return_value = MagicMock() + mock_model_forward.return_value = torch.randn(2, 10, 100) + + processed_mb = ProcessedMicrobatch( + data_dict=MagicMock(), + input_ids=torch.tensor([[1, 2, 3]]), + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=torch.ones(1, 3), + position_ids=torch.tensor([[0, 1, 2]]), + packed_seq_params=None, + cu_seqlens_padded=None, + ) + + cfg = {"sequence_packing": {"enabled": False}} + post_processor = LossPostProcessor(loss_fn=MagicMock(), cfg=cfg) + mock_timer = MagicMock() + + forward_with_post_processing_fn( + data_iterator=iter([processed_mb]), + model=MagicMock(), + cfg=cfg, + post_processing_fn=post_processor, + straggler_timer=mock_timer, + ) + + # Verify straggler_timer was passed to model_forward + call_kwargs = mock_model_forward.call_args[1] + assert call_kwargs["straggler_timer"] is mock_timer + + @patch("nemo_rl.models.megatron.train.model_forward") + def test_forward_with_unknown_post_processor_raises(self, mock_model_forward): + """Test that unknown post-processor type raises TypeError.""" + from nemo_rl.models.megatron.data import ProcessedMicrobatch + from nemo_rl.models.megatron.train import forward_with_post_processing_fn + + mock_model_forward.return_value = torch.randn(2, 10, 100) + + processed_mb = ProcessedMicrobatch( + data_dict=MagicMock(), + input_ids=torch.tensor([[1, 2, 3]]), + input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), + attention_mask=None, + position_ids=None, + packed_seq_params=None, + cu_seqlens_padded=None, + ) + + data_iterator = iter([processed_mb]) + unknown_processor = "not_a_processor" + + with pytest.raises(TypeError, match="Unknown post-processing function type"): + forward_with_post_processing_fn( + data_iterator=data_iterator, + model=MagicMock(), + cfg={}, + post_processing_fn=unknown_processor, + ) + + +class TestMegatronForwardBackward: + """Tests for megatron_forward_backward function.""" + + @patch("nemo_rl.models.megatron.train.get_forward_backward_func") + def test_megatron_forward_backward_calls_forward_backward_func(self, mock_get_fb): + """Test that megatron_forward_backward calls the forward_backward_func.""" + from nemo_rl.models.megatron.train import ( + LossPostProcessor, + megatron_forward_backward, + ) + + mock_fb_func = MagicMock(return_value={"loss": torch.tensor(0.5)}) + mock_get_fb.return_value = mock_fb_func + + mock_model = MagicMock() + mock_loss_fn = MagicMock() + cfg = {"sequence_packing": {"enabled": False}} + post_processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg) + + result = megatron_forward_backward( + model=mock_model, + cfg=cfg, + data_iterator=iter([]), + num_microbatches=4, + seq_length=128, + mbs=2, + post_processing_fn=post_processor, + ) + + mock_get_fb.assert_called_once() + mock_fb_func.assert_called_once() + + # Verify key arguments + call_kwargs = mock_fb_func.call_args[1] + assert call_kwargs["num_microbatches"] == 4 + assert call_kwargs["seq_length"] == 128 + assert call_kwargs["micro_batch_size"] == 2 + + @patch("nemo_rl.models.megatron.train.get_forward_backward_func") + def test_megatron_forward_backward_forward_only(self, mock_get_fb): + """Test megatron_forward_backward with forward_only=True.""" + from nemo_rl.models.megatron.train import ( + LossPostProcessor, + megatron_forward_backward, + ) + + mock_fb_func = MagicMock() + mock_get_fb.return_value = mock_fb_func + + cfg = {"sequence_packing": {"enabled": False}} + post_processor = LossPostProcessor(loss_fn=MagicMock(), cfg=cfg) + + megatron_forward_backward( + model=MagicMock(), + 
cfg=cfg, + data_iterator=iter([]), + num_microbatches=1, + seq_length=64, + mbs=1, + post_processing_fn=post_processor, + forward_only=True, + ) + + call_kwargs = mock_fb_func.call_args[1] + assert call_kwargs["forward_only"] is True + + +class TestLossPostProcessor: + """Tests for LossPostProcessor class.""" + + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_context_parallel_world_size", return_value=1 + ) + def test_loss_post_processor_no_packing( + self, mock_cp_size, mock_cp_grp, mock_tp_grp, mock_tp_rank + ): + """Test LossPostProcessor without sequence packing.""" + from nemo_rl.models.megatron.train import LossPostProcessor + + mock_loss_fn = MagicMock(return_value=(torch.tensor(0.5), {"loss": 0.5})) + cfg = {"sequence_packing": {"enabled": False}} + + processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg, cp_normalize=False) + + # Set up mock return values for process groups + mock_tp_grp.return_value = MagicMock() + mock_cp_grp.return_value = MagicMock() + + wrapped_fn = processor( + data_dict=MagicMock(), + packed_seq_params=None, + global_valid_seqs=torch.tensor(10), + global_valid_toks=torch.tensor(100), + ) + + # Call the wrapped function + output_tensor = torch.randn(2, 10, 100) + loss, metrics = wrapped_fn(output_tensor) + + assert torch.isclose(loss, torch.tensor(0.5)) + assert isinstance(metrics, dict) + assert len(metrics) == 1 and metrics["loss"] == 0.5 + + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_context_parallel_world_size", return_value=2 + ) + def test_loss_post_processor_with_cp_normalize( + self, mock_cp_size, mock_cp_grp, mock_tp_grp, mock_tp_rank + ): + """Test LossPostProcessor with CP normalization.""" + from nemo_rl.models.megatron.train import LossPostProcessor + + mock_loss_fn = MagicMock(return_value=(torch.tensor(1.0), {})) + cfg = {"sequence_packing": {"enabled": False}} + + processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg, cp_normalize=True) + + # Set up mock return values for process groups + mock_tp_grp.return_value = MagicMock() + mock_cp_grp.return_value = MagicMock() + + wrapped_fn = processor(data_dict=MagicMock()) + + output_tensor = torch.randn(2, 10, 100) + loss, _ = wrapped_fn(output_tensor) + + # Loss should be divided by CP size (2) + assert torch.isclose(loss, torch.tensor(0.5)) + + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_context_parallel_world_size", return_value=1 + ) + @patch("nemo_rl.models.megatron.train.SequencePackingLossWrapper") + def test_loss_post_processor_with_packing( + self, mock_wrapper, mock_cp_size, mock_cp_grp, mock_tp_grp, mock_tp_rank + ): + """Test LossPostProcessor with sequence packing.""" + from nemo_rl.models.megatron.train import LossPostProcessor + + # Set up mock return values for process groups + mock_tp_grp.return_value = MagicMock() + mock_cp_grp.return_value = MagicMock() + + 
mock_loss_fn = MagicMock() + cfg = {"sequence_packing": {"enabled": True}} + + mock_packed_seq_params = MagicMock() + mock_packed_seq_params.cu_seqlens_q = torch.tensor([0, 5, 10]) + mock_packed_seq_params.cu_seqlens_q_padded = torch.tensor([0, 8, 16]) + + processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg, cp_normalize=False) + + processor(data_dict=MagicMock(), packed_seq_params=mock_packed_seq_params) + + # Verify SequencePackingLossWrapper was called + mock_wrapper.assert_called_once() + + +class TestLogprobsPostProcessor: + """Tests for LogprobsPostProcessor class.""" + + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.from_parallel_logits_to_logprobs") + def test_logprobs_post_processor_no_packing( + self, mock_from_logits, mock_tp_rank, mock_tp_grp + ): + """Test LogprobsPostProcessor without sequence packing.""" + from nemo_rl.models.megatron.train import LogprobsPostProcessor + + # Set up mock return values for process groups + mock_tp_grp.return_value = MagicMock() + + cfg = {"sequence_packing": {"enabled": False}} + processor = LogprobsPostProcessor(cfg=cfg) + + mock_data_dict = MagicMock() + mock_data_dict.__getitem__ = MagicMock( + return_value=torch.tensor([[1, 2, 3, 4, 5]]) + ) + + mock_logprobs = torch.randn(1, 4) # One less than input length + mock_from_logits.return_value = mock_logprobs + + wrapped_fn = processor( + data_dict=mock_data_dict, + input_ids=torch.tensor([[1, 2, 3, 4, 5]]), + cu_seqlens_padded=None, + ) + + output_tensor = torch.randn(1, 5, 100) + loss, result = wrapped_fn(output_tensor) + + # Loss should be 0 + assert loss.item() == 0.0 + # Result should have logprobs key + assert "logprobs" in result + # Logprobs should be prepended with a 0 + assert result["logprobs"].shape[1] == 5 + + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch( + "nemo_rl.models.megatron.train.from_parallel_logits_to_logprobs_packed_sequences" + ) + def test_logprobs_post_processor_with_packing( + self, mock_from_logits_packed, mock_cp_grp, mock_tp_rank, mock_tp_grp + ): + """Test LogprobsPostProcessor with sequence packing.""" + from nemo_rl.models.megatron.train import LogprobsPostProcessor + + # Set up mock return values for process groups + mock_tp_grp.return_value = MagicMock() + mock_cp_grp.return_value = MagicMock() + + cfg = {"sequence_packing": {"enabled": True}} + processor = LogprobsPostProcessor(cfg=cfg) + + mock_data_dict = MagicMock() + mock_data_dict.__getitem__ = MagicMock( + return_value=torch.tensor([[1, 2, 3, 4, 5]]) + ) + + mock_logprobs = torch.randn(1, 4) + mock_from_logits_packed.return_value = mock_logprobs + + wrapped_fn = processor( + data_dict=mock_data_dict, + input_ids=torch.tensor([[1, 2, 3, 4, 5]]), + cu_seqlens_padded=torch.tensor([0, 5]), + ) + + output_tensor = torch.randn(1, 5, 100) + loss, result = wrapped_fn(output_tensor) + + mock_from_logits_packed.assert_called_once() + assert "logprobs" in result + + +class TestTopkLogitsPostProcessor: + """Tests for TopkLogitsPostProcessor class.""" + + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + 
@patch("nemo_rl.models.megatron.train.distributed_vocab_topk") + def test_topk_post_processor_no_packing(self, mock_topk, mock_tp_rank, mock_tp_grp): + """Test TopkLogitsPostProcessor without sequence packing.""" + from nemo_rl.models.megatron.train import TopkLogitsPostProcessor + + # Set up mock return values for process groups + mock_tp_grp.return_value = MagicMock() + + cfg = { + "sequence_packing": {"enabled": False}, + "megatron_cfg": {"context_parallel_size": 1}, + } + k = 5 + processor = TopkLogitsPostProcessor(cfg=cfg, k=k) + + mock_data_dict = MagicMock() + mock_data_dict.__getitem__ = MagicMock( + side_effect=lambda key: torch.tensor([[1, 2, 3, 4, 5]]) + if key == "input_ids" + else torch.tensor([5]) + ) + + mock_topk_vals = torch.randn(1, 5, k) + mock_topk_idx = torch.randint(0, 100, (1, 5, k)) + mock_topk.return_value = (mock_topk_vals, mock_topk_idx) + + wrapped_fn = processor( + data_dict=mock_data_dict, + cu_seqlens_padded=None, + ) + + output_tensor = torch.randn(1, 5, 100) + loss, result = wrapped_fn(output_tensor) + + assert "topk_logits" in result + assert "topk_indices" in result + assert result["topk_logits"].shape[-1] == k + + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.distributed_vocab_topk") + def test_topk_post_processor_with_packing( + self, mock_topk, mock_tp_rank, mock_tp_grp + ): + """Test TopkLogitsPostProcessor with sequence packing.""" + from nemo_rl.models.megatron.train import TopkLogitsPostProcessor + + # Set up mock return values for process groups + mock_tp_grp.return_value = MagicMock() + + cfg = { + "sequence_packing": {"enabled": True}, + "megatron_cfg": {"context_parallel_size": 1}, + } + k = 3 + processor = TopkLogitsPostProcessor(cfg=cfg, k=k) + + mock_data_dict = MagicMock() + mock_data_dict.__getitem__ = MagicMock( + side_effect=lambda key: torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]) + if key == "input_ids" + else torch.tensor([5]) + ) + + mock_topk_vals = torch.randn(1, 8, k) + mock_topk_idx = torch.randint(0, 100, (1, 8, k)) + mock_topk.return_value = (mock_topk_vals, mock_topk_idx) + + cu_seqlens_padded = torch.tensor([0, 5]) + + wrapped_fn = processor( + data_dict=mock_data_dict, + cu_seqlens_padded=cu_seqlens_padded, + ) + + output_tensor = torch.randn(1, 8, 100) + loss, result = wrapped_fn(output_tensor) + + assert "topk_logits" in result + assert "topk_indices" in result + # Output should be unpacked to batch shape + assert result["topk_logits"].shape[0] == 1 + + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.distributed_vocab_topk") + def test_topk_cp_without_packing_raises( + self, mock_topk, mock_tp_rank, mock_tp_grp, mock_cp_grp + ): + """Test that CP > 1 without packing raises RuntimeError.""" + from nemo_rl.models.megatron.train import TopkLogitsPostProcessor + + # Set up mock return values for process groups + mock_tp_grp.return_value = MagicMock() + mock_cp_grp.return_value = MagicMock() + + cfg = { + "sequence_packing": {"enabled": False}, + "megatron_cfg": {"context_parallel_size": 2}, + } + processor = TopkLogitsPostProcessor(cfg=cfg, k=5) + + mock_data_dict = MagicMock() + mock_data_dict.__getitem__ = MagicMock( + side_effect=lambda key: 
torch.tensor([[1, 2, 3]]) + if key == "input_ids" + else torch.tensor([3]) + ) + + mock_topk.return_value = ( + torch.randn(1, 3, 5), + torch.randint(0, 100, (1, 3, 5)), + ) + + wrapped_fn = processor(data_dict=mock_data_dict, cu_seqlens_padded=None) + + output_tensor = torch.randn(1, 3, 100) + + with pytest.raises( + RuntimeError, match="Context Parallelism.*requires sequence packing" + ): + wrapped_fn(output_tensor) + + @patch("nemo_rl.models.megatron.train.allgather_cp_sharded_tensor") + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.distributed_vocab_topk") + def test_topk_cp_with_packing_single_sequence( + self, mock_topk, mock_tp_rank, mock_tp_grp, mock_cp_grp, mock_allgather + ): + """Test TopkLogitsPostProcessor with CP > 1 and packing for a single sequence.""" + from nemo_rl.models.megatron.train import TopkLogitsPostProcessor + + mock_tp_grp.return_value = MagicMock() + mock_cp_grp.return_value = MagicMock() + + cp_size = 2 + k = 3 + seq_len = 8 # Total packed length + local_seq_len = seq_len // cp_size # Each CP rank sees half + + cfg = { + "sequence_packing": {"enabled": True}, + "megatron_cfg": {"context_parallel_size": cp_size}, + } + processor = TopkLogitsPostProcessor(cfg=cfg, k=k) + + mock_data_dict = MagicMock() + mock_data_dict.__getitem__ = MagicMock( + side_effect=lambda key: torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + if key == "input_ids" + else torch.tensor([8]) + ) + + # distributed_vocab_topk returns local (CP-sharded) results + mock_topk_vals = torch.randn(1, local_seq_len, k) + mock_topk_idx = torch.randint(0, 100, (1, local_seq_len, k)) + mock_topk.return_value = (mock_topk_vals, mock_topk_idx) + + # allgather returns the full gathered tensor + gathered_vals = torch.randn(1, seq_len, k) + gathered_idx = torch.randint(0, 100, (1, seq_len, k)) + mock_allgather.side_effect = [gathered_vals, gathered_idx] + + cu_seqlens_padded = torch.tensor([0, seq_len]) + + wrapped_fn = processor( + data_dict=mock_data_dict, + cu_seqlens_padded=cu_seqlens_padded, + ) + + output_tensor = torch.randn(1, local_seq_len, 100) + loss, result = wrapped_fn(output_tensor) + + # Verify allgather was called for both vals and indices + assert mock_allgather.call_count == 2 + assert "topk_logits" in result + assert "topk_indices" in result + # Output should be unpacked: (batch_size=1, unpacked_seqlen=8, k=3) + assert result["topk_logits"].shape == (1, 8, k) + assert result["topk_indices"].shape == (1, 8, k) + + @patch("nemo_rl.models.megatron.train.allgather_cp_sharded_tensor") + @patch("nemo_rl.models.megatron.train.get_context_parallel_group") + @patch("nemo_rl.models.megatron.train.get_tensor_model_parallel_group") + @patch( + "nemo_rl.models.megatron.train.get_tensor_model_parallel_rank", return_value=0 + ) + @patch("nemo_rl.models.megatron.train.distributed_vocab_topk") + def test_topk_cp_with_packing_multiple_sequences( + self, mock_topk, mock_tp_rank, mock_tp_grp, mock_cp_grp, mock_allgather + ): + """Test TopkLogitsPostProcessor with CP > 1, packing, and multiple sequences in batch.""" + from nemo_rl.models.megatron.train import TopkLogitsPostProcessor + + mock_tp_grp.return_value = MagicMock() + mock_cp_grp.return_value = MagicMock() + + cp_size = 2 + k = 3 + # Two sequences packed: seq1 has 4 tokens, seq2 has 6 tokens => total packed = 10 + seq1_len = 4 + 
seq2_len = 6 + total_packed_len = seq1_len + seq2_len + local_packed_len = total_packed_len // cp_size + unpacked_seqlen = 6 # Max seq length in batch (for output shape) + + cfg = { + "sequence_packing": {"enabled": True}, + "megatron_cfg": {"context_parallel_size": cp_size}, + } + processor = TopkLogitsPostProcessor(cfg=cfg, k=k) + + mock_data_dict = MagicMock() + mock_data_dict.__getitem__ = MagicMock( + side_effect=lambda key: torch.zeros(2, unpacked_seqlen, dtype=torch.long) + if key == "input_ids" + else torch.tensor([seq1_len, seq2_len]) + ) + + # distributed_vocab_topk returns local (CP-sharded) results + mock_topk_vals = torch.randn(1, local_packed_len, k) + mock_topk_idx = torch.randint(0, 100, (1, local_packed_len, k)) + mock_topk.return_value = (mock_topk_vals, mock_topk_idx) + + # allgather is called once per sequence (2 sequences x 2 tensors = 4 calls) + def fake_allgather(local_tensor, group, seq_dim): + # Simulate gathering: double the seq_dim since cp_size=2 + return local_tensor.repeat(1, cp_size, 1) + + mock_allgather.side_effect = fake_allgather + + cu_seqlens_padded = torch.tensor([0, seq1_len, total_packed_len]) + + wrapped_fn = processor( + data_dict=mock_data_dict, + cu_seqlens_padded=cu_seqlens_padded, + ) + + output_tensor = torch.randn(1, local_packed_len, 100) + loss, result = wrapped_fn(output_tensor) + + # allgather called 2x per sequence (vals + idx) x 2 sequences = 4 calls + assert mock_allgather.call_count == 4 + assert "topk_logits" in result + assert "topk_indices" in result + # Output should be unpacked: (batch_size=2, unpacked_seqlen=6, k=3) + assert result["topk_logits"].shape == (2, unpacked_seqlen, k) + assert result["topk_indices"].shape == (2, unpacked_seqlen, k) + + +class TestAggregateTrainingStatistics: + """Tests for aggregate_training_statistics function.""" + + @patch("torch.distributed.all_reduce") + def test_aggregates_metrics_across_microbatches(self, mock_all_reduce): + """Test that per-microbatch metrics are collected into lists by key.""" + from nemo_rl.models.megatron.train import aggregate_training_statistics + + all_mb_metrics = [ + {"loss": 0.5, "lr": 1e-4}, + {"loss": 0.3, "lr": 1e-4}, + {"loss": 0.2, "lr": 1e-4}, + ] + + mock_dp_group = MagicMock() + + mb_metrics, _ = aggregate_training_statistics( + all_mb_metrics=all_mb_metrics, + losses=[1.0], + data_parallel_group=mock_dp_group, + ) + + assert mb_metrics["loss"] == [0.5, 0.3, 0.2] + assert mb_metrics["lr"] == [1e-4, 1e-4, 1e-4] + assert len(mb_metrics) == 2 + + @patch("torch.distributed.all_reduce") + def test_returns_plain_dict(self, mock_all_reduce): + """Test that the returned mb_metrics is a plain dict, not defaultdict.""" + from nemo_rl.models.megatron.train import aggregate_training_statistics + + mb_metrics, _ = aggregate_training_statistics( + all_mb_metrics=[{"loss": 0.5}], + losses=[1.0], + data_parallel_group=MagicMock(), + ) + + assert type(mb_metrics) is dict + + @patch("torch.distributed.all_reduce") + def test_global_loss_tensor_from_losses(self, mock_all_reduce): + """Test that losses list is converted to a CUDA tensor for all-reduce.""" + from nemo_rl.models.megatron.train import aggregate_training_statistics + + mock_dp_group = MagicMock() + + _, global_loss = aggregate_training_statistics( + all_mb_metrics=[], + losses=[0.5, 0.3, 0.2], + data_parallel_group=mock_dp_group, + ) + + # Verify all_reduce was called with correct args + mock_all_reduce.assert_called_once() + call_args = mock_all_reduce.call_args + assert call_args[1]["op"] == 
torch.distributed.ReduceOp.SUM + assert call_args[1]["group"] is mock_dp_group + + # Verify tensor shape matches losses list + reduced_tensor = call_args[0][0] + assert reduced_tensor.shape == (3,) + + @patch("torch.distributed.all_reduce") + def test_empty_metrics(self, mock_all_reduce): + """Test with empty microbatch metrics list.""" + from nemo_rl.models.megatron.train import aggregate_training_statistics + + mb_metrics, global_loss = aggregate_training_statistics( + all_mb_metrics=[], + losses=[1.0], + data_parallel_group=MagicMock(), + ) + + assert mb_metrics == {} + mock_all_reduce.assert_called_once() + + @patch("torch.distributed.all_reduce") + def test_handles_heterogeneous_metric_keys(self, mock_all_reduce): + """Test that microbatches with different metric keys are handled correctly.""" + from nemo_rl.models.megatron.train import aggregate_training_statistics + + all_mb_metrics = [ + {"loss": 0.5, "lr": 1e-4}, + {"loss": 0.3, "global_valid_seqs": 8}, + ] + + mb_metrics, _ = aggregate_training_statistics( + all_mb_metrics=all_mb_metrics, + losses=[0.8], + data_parallel_group=MagicMock(), + ) + + assert mb_metrics["loss"] == [0.5, 0.3] + assert mb_metrics["lr"] == [1e-4] + assert mb_metrics["global_valid_seqs"] == [8] + + @patch("torch.distributed.all_reduce") + def test_no_grad_context(self, mock_all_reduce): + """Test that all-reduce runs under torch.no_grad context.""" + from nemo_rl.models.megatron.train import aggregate_training_statistics + + grad_enabled_during_all_reduce = [] + + def capture_grad_state(*args, **kwargs): + grad_enabled_during_all_reduce.append(torch.is_grad_enabled()) + + mock_all_reduce.side_effect = capture_grad_state + + aggregate_training_statistics( + all_mb_metrics=[], + losses=[1.0], + data_parallel_group=MagicMock(), + ) + + assert grad_enabled_during_all_reduce == [False]
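
For reference, the behavior pinned down by TestApplyTemperatureScaling can be restated as a short standalone sketch: logits are divided in place by cfg["generation"]["temperature"] when a generation config is present, and left untouched otherwise. This is an illustrative restatement of the asserted contract, not the implementation in nemo_rl.models.megatron.train; the helper name reference_temperature_scaling is hypothetical.

# Sketch of the contract asserted in TestApplyTemperatureScaling.
import torch


def reference_temperature_scaling(logits: torch.Tensor, cfg: dict) -> torch.Tensor:
    generation_cfg = cfg.get("generation")
    if generation_cfg is not None:
        # In-place division, so the returned tensor aliases the input,
        # matching the data_ptr() check in the tests.
        logits.div_(generation_cfg["temperature"])
    return logits


logits = torch.ones(2, 10, 100) * 4.0
out = reference_temperature_scaling(logits, {"generation": {"temperature": 2.0}})
assert out.data_ptr() == logits.data_ptr()
assert torch.allclose(out, torch.full_like(out, 2.0))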
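
Similarly, test_loss_post_processor_with_cp_normalize fixes the cp_normalize contract: the loss returned by the wrapped loss function is divided by the context-parallel world size. The sketch below shows only that wrapping step in isolation, under the assumption that the loss function returns a (loss, metrics) pair; wrap_with_cp_normalization is an illustrative name, not the real LossPostProcessor.

# Sketch of the cp_normalize behavior asserted in
# test_loss_post_processor_with_cp_normalize.
from typing import Any, Callable

import torch

LossFn = Callable[[torch.Tensor], tuple[torch.Tensor, dict[str, Any]]]


def wrap_with_cp_normalization(loss_fn: LossFn, cp_size: int) -> LossFn:
    def wrapped(output_tensor: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        loss, metrics = loss_fn(output_tensor)
        # Normalize the loss by the context-parallel world size.
        return loss / cp_size, metrics

    return wrapped


base_loss_fn: LossFn = lambda output: (torch.tensor(1.0), {})
wrapped_fn = wrap_with_cp_normalization(base_loss_fn, cp_size=2)
loss, _ = wrapped_fn(torch.randn(2, 10, 100))
assert torch.isclose(loss, torch.tensor(0.5))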
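
Finally, the TestAggregateTrainingStatistics cases describe the expected collection behavior: per-microbatch metric dicts are merged into lists keyed by metric name, heterogeneous keys are tolerated, and a plain dict (not a defaultdict) is returned, while losses are separately summed across data-parallel ranks via all_reduce. The sketch below covers only the metric-collection half, since the distributed reduction needs an initialized process group; collect_mb_metrics is an illustrative name rather than the function under test.

# Sketch of the metric-collection behavior asserted in
# TestAggregateTrainingStatistics.
from collections import defaultdict
from typing import Any


def collect_mb_metrics(all_mb_metrics: list[dict[str, Any]]) -> dict[str, list[Any]]:
    grouped: defaultdict[str, list[Any]] = defaultdict(list)
    for mb in all_mb_metrics:
        for key, value in mb.items():
            grouped[key].append(value)
    # Return a plain dict, as checked by test_returns_plain_dict.
    return dict(grouped)


metrics = collect_mb_metrics(
    [{"loss": 0.5, "lr": 1e-4}, {"loss": 0.3, "global_valid_seqs": 8}]
)
assert metrics == {"loss": [0.5, 0.3], "lr": [1e-4], "global_valid_seqs": [8]}
assert type(metrics) is dict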