diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index c464231aeb..7ff19739ea 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -322,6 +322,7 @@ logger: mlflow_enabled: false # Disable MLflow logging swanlab_enabled: false # Disable SwanLab logging monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard + collect_worker_init_timing: false # If true, collects and saves worker initialization timing to JSON wandb: project: "grpo-dev" name: "grpo-dev-logger" diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index f223fa091c..fc8a1e91e6 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -74,7 +74,7 @@ ) from nemo_rl.utils.memory_tracker import MemoryTracker from nemo_rl.utils.nsys import maybe_gpu_profile_step -from nemo_rl.utils.timer import TimeoutChecker, Timer +from nemo_rl.utils.timer import TimeoutChecker, Timer, save_worker_init_timing from nemo_rl.utils.venvs import create_local_venv_on_each_node # =============================================================================== @@ -655,6 +655,16 @@ def initialize_generation_with_policy( ray.get(futures_train + futures_inference) worker_init_timing_metrics["collective_init_time_s"] = time.perf_counter() - t0 + # Collect worker initialization timing if enabled + if master_config["logger"].get("collect_worker_init_timing", False): + worker_groups = {"policy": policy.worker_group} + if policy_generation is not None: + worker_groups["vllm"] = policy_generation.worker_group + save_worker_init_timing( + worker_groups, + Path(master_config["logger"]["log_dir"]) / "worker_init_timing.json", + ) + # prepare refit info state_dict_info = policy.prepare_refit_info() if policy_generation is not None: diff --git a/nemo_rl/distributed/worker_groups.py b/nemo_rl/distributed/worker_groups.py index e4045183c2..29ae6be509 100644 --- a/nemo_rl/distributed/worker_groups.py +++ b/nemo_rl/distributed/worker_groups.py @@ -29,6 +29,7 @@ ) from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.distributed.worker_group_utils import recursive_merge_options +from nemo_rl.utils.timer import Timer from nemo_rl.utils.venvs import ( create_local_venv_on_each_node, ) @@ -1029,3 +1030,30 @@ def shutdown( self._worker_metadata = [] return success + + def collect_init_timing(self) -> dict[str, float]: + """Collect and aggregate initialization timing from all workers. + + Returns: + dict[str, float]: Dictionary mapping timing labels to aggregated max values. + Returns empty dict if workers don't have timing or on error. + """ + if not self._workers: + return {} + + # Collect timing from all workers + timing_futures = [] + for worker in self._workers: + if hasattr(worker, "get_init_timer"): + timing_futures.append(worker.get_init_timer.remote()) + + if not timing_futures: + return {} + + # Get all timers + timers = ray.get(timing_futures) + + # Aggregate using max across workers, sum within each worker + aggregated = Timer.aggregate_max(timers, reduction_op="sum") + + return aggregated diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 9238533cd2..edf787d607 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# Measure module import time (import time first for measurement) +import time + +G_MODULE_IMPORT_START_TIME = time.perf_counter() + import copy import gc import os @@ -35,6 +40,10 @@ from nemo_rl.models.huggingface.common import ModelFlag from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled from nemo_rl.utils.nsys import wrap_with_nvtx_name +from nemo_rl.utils.timer import Timer + +# Calculate module import duration after all imports +G_MODULE_IMPORT_DURATION = time.perf_counter() - G_MODULE_IMPORT_START_TIME # Use a base class to share some functions to avoid code duplication. @@ -142,6 +151,14 @@ def __init__( self.fraction_of_gpus = fraction_of_gpus self.is_model_owner = bundle_indices is not None + # Initialize timer for tracking initialization stages + self.init_timer = Timer() + self.init_timer.start("total_init") + + # Record module import time + if "G_MODULE_IMPORT_DURATION" in globals(): + self.init_timer._timers["module_import"] = [G_MODULE_IMPORT_DURATION] + # Store the Python executable being used by this worker self.py_executable = sys.executable @@ -151,6 +168,7 @@ def __init__( self.tokenizer = None self.rank = 0 self.world_size = 1 + self.init_timer.stop("total_init") return # In Ray+vLLM setup, each worker process considers itself rank 0 @@ -420,12 +438,15 @@ def _patch_vllm_vit_flash_attn_backend(): **vllm_kwargs, ) - self._create_engine(llm_kwargs) + with self.init_timer.time("create_engine"): + self._create_engine(llm_kwargs) # will be initialized in post_init # used in update_weights_from_ipc_handles self.vllm_device_ids = None + self.init_timer.stop("total_init") + def llm(self): return self.llm @@ -433,6 +454,10 @@ def is_alive(self): """Check if the worker is alive.""" return True + def get_init_timer(self) -> Timer: + """Return init timing for controller aggregation.""" + return self.init_timer + def _merge_stop_strings(self, batch_stop_strings): stop_set: set[str] = set() diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 24bfdb0605..8dc79aea2b 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -15,6 +15,7 @@ import os import time import warnings +from contextlib import nullcontext from typing import Any, Optional, TypeVar import torch @@ -73,6 +74,7 @@ configure_dynamo_cache, get_megatron_checkpoint_dir, ) +from nemo_rl.utils.timer import Timer TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) @@ -624,17 +626,23 @@ def setup_model_and_optimizer( load_optimizer: bool = True, get_embedding_ranks=None, # TODO @sahilj: What is this? 
get_position_embedding_ranks=None, + init_timer: Optional[Timer] = None, ): + def timer_context(label: str): + """Helper to conditionally use timer context.""" + return init_timer.time(label) if init_timer is not None else nullcontext() + state = GlobalState() state.cfg = megatron_cfg # TODO: Freeze state.cfg megatron_cfg.dist.external_gpu_device_mapping = True - initialize_megatron( - cfg=megatron_cfg, - get_embedding_ranks=get_embedding_ranks, - get_position_embedding_ranks=get_position_embedding_ranks, - ) + with timer_context("initialize_megatron"): + initialize_megatron( + cfg=megatron_cfg, + get_embedding_ranks=get_embedding_ranks, + get_position_embedding_ranks=get_position_embedding_ranks, + ) if megatron_cfg.ft and megatron_cfg.ft.enable_ft_package: fault_tolerance.setup(megatron_cfg, state) @@ -731,22 +739,24 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: pre_wrap_hook.extend([composed_peft_hook]) # Model, optimizer, and learning rate. - model = get_model( - megatron_cfg.model, - megatron_cfg.ddp, - use_torch_fsdp2=megatron_cfg.dist.use_torch_fsdp2, - overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step, - data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init, - pre_wrap_hook=pre_wrap_hook, - mixed_precision_wrapper=mixed_precision_wrapper, - ) - if load_optimizer: - optimizer, scheduler = setup_optimizer( - optimizer_config=megatron_cfg.optimizer, - scheduler_config=megatron_cfg.scheduler, - model=model, - use_gloo_process_groups=megatron_cfg.dist.use_gloo_process_groups, + with timer_context("model_init"): + model = get_model( + megatron_cfg.model, + megatron_cfg.ddp, + use_torch_fsdp2=megatron_cfg.dist.use_torch_fsdp2, + overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step, + data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init, + pre_wrap_hook=pre_wrap_hook, + mixed_precision_wrapper=mixed_precision_wrapper, ) + if load_optimizer: + with timer_context("optimizer_init"): + optimizer, scheduler = setup_optimizer( + optimizer_config=megatron_cfg.optimizer, + scheduler_config=megatron_cfg.scheduler, + model=model, + use_gloo_process_groups=megatron_cfg.dist.use_gloo_process_groups, + ) else: optimizer = None scheduler = None @@ -774,14 +784,16 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: # Load checkpoint if applicable if should_load_checkpoint: - load_checkpoint( - state, - model, - optimizer, - scheduler, - checkpointing_context=checkpointing_context, - skip_load_to_model_and_opt=HAVE_FSDP2 and megatron_cfg.dist.use_torch_fsdp2, - ) + with timer_context("load_checkpoint"): + load_checkpoint( + state, + model, + optimizer, + scheduler, + checkpointing_context=checkpointing_context, + skip_load_to_model_and_opt=HAVE_FSDP2 + and megatron_cfg.dist.use_torch_fsdp2, + ) print("Checkpoint loaded") torch.distributed.barrier() diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 6028506f92..752ac138fc 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# Measure module import time (import time first for measurement) +import time + +G_MODULE_IMPORT_START_TIME = time.perf_counter() + import contextlib import gc import itertools @@ -83,6 +88,10 @@ ) from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_producer +from nemo_rl.utils.timer import Timer + +# Calculate module import duration after all imports +G_MODULE_IMPORT_DURATION = time.perf_counter() - G_MODULE_IMPORT_START_TIME @contextmanager @@ -158,10 +167,17 @@ def __init__( init_reference_model: bool = True, **kwargs: Any, ): + """Initialize the DTensorPolicyWorker.""" + self.init_timer = Timer() + self.init_timer.start("total_init") + + # Record module import time + if "G_MODULE_IMPORT_DURATION" in globals(): + self.init_timer._timers["module_import"] = [G_MODULE_IMPORT_DURATION] + # Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered' apply_torch_aten_alias_tensor_patch() - """Initialize the DTensorPolicyWorker.""" self.tokenizer = tokenizer self.processor = processor self.is_vlm = processor is not None @@ -183,7 +199,8 @@ def __init__( self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call - torch.distributed.init_process_group(backend="nccl") + with self.init_timer.time("setup_distributed_nccl"): + torch.distributed.init_process_group(backend="nccl") self.rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() model_name = self.cfg["model_name"] @@ -260,26 +277,27 @@ def __init__( # DO NOT assume AutoModelForCausalLM, multimodal models can inherit from AutoModelForImageTextToText, AutoModelForTextToWaveform, etc. model_class = resolve_model_class(model_config.model_type) - full_state_dict = None - if self.rank == 0: - print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") - model = model_class.from_pretrained( - model_name, - device_map="cpu", # load weights onto CPU initially - trust_remote_code=True, - config=model_config, - ) - full_state_dict = model.state_dict() - del model - - print(f"[Rank {self.rank}] Initializing empty model for FSDP...") - # All ranks initialize model on meta device, so FSDP can shard it. - # The actual weights will be broadcast from rank 0. - with init_empty_weights(): - self.model = model_class.from_config( - model_config, - trust_remote_code=True, - ) + with self.init_timer.time("model_loading"): + full_state_dict = None + if self.rank == 0: + print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") + model = model_class.from_pretrained( + model_name, + device_map="cpu", # load weights onto CPU initially + trust_remote_code=True, + config=model_config, + ) + full_state_dict = model.state_dict() + del model + + print(f"[Rank {self.rank}] Initializing empty model for FSDP...") + # All ranks initialize model on meta device, so FSDP can shard it. + # The actual weights will be broadcast from rank 0. 
+ with init_empty_weights(): + self.model = model_class.from_config( + model_config, + trust_remote_code=True, + ) if self.model.config.pad_token_id is None: self.model.config.pad_token_id = tokenizer.pad_token_id @@ -344,30 +362,32 @@ def __init__( # 3) Move to GPU + Composable FSDP # (Initialize device mesh, shard submodules, then shard entire model) # ------------------------------------------------ - self.model = _parallelize_model( - self.model, - self.dp_cp_mesh, - self.tp_mesh, - param_dtype=self.dtype, - sequence_parallel=sequence_parallel_enabled, - cpu_offload=self.cpu_offload, - activation_checkpointing=self.cfg["dtensor_cfg"][ - "activation_checkpointing" - ], - custom_parallel_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"], - ) + with self.init_timer.time("model_parallelization"): + self.model = _parallelize_model( + self.model, + self.dp_cp_mesh, + self.tp_mesh, + param_dtype=self.dtype, + sequence_parallel=sequence_parallel_enabled, + cpu_offload=self.cpu_offload, + activation_checkpointing=self.cfg["dtensor_cfg"][ + "activation_checkpointing" + ], + custom_parallel_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"], + ) - print(f"[Rank {self.rank}] Loading state dict from rank 0...") - # This will broadcast the state dict from rank 0 to all other ranks - # and load it into the FSDP model. - set_model_state_dict( - self.model, - model_state_dict=full_state_dict, - options=StateDictOptions( - full_state_dict=True, - broadcast_from_rank0=True, - ), - ) + with self.init_timer.time("state_dict_broadcast"): + print(f"[Rank {self.rank}] Loading state dict from rank 0...") + # This will broadcast the state dict from rank 0 to all other ranks + # and load it into the FSDP model. + set_model_state_dict( + self.model, + model_state_dict=full_state_dict, + options=StateDictOptions( + full_state_dict=True, + broadcast_from_rank0=True, + ), + ) # Handle tied word embeddings after loading the state dict # We need to actually tie the parameters at the model level @@ -392,15 +412,17 @@ def __init__( self.model = self.move_to_device(self.model, "cpu") if init_reference_model: - self.reference_model_state_dict = get_cpu_state_dict( - self.model.state_dict().items(), pin_memory=True - ) + with self.init_timer.time("reference_model_setup"): + self.reference_model_state_dict = get_cpu_state_dict( + self.model.state_dict().items(), pin_memory=True + ) if init_optimizer: - optimizer_cls = get_class(self.cfg["optimizer"]["name"]) - self.optimizer = optimizer_cls( - self.model.parameters(), **self.cfg["optimizer"]["kwargs"] - ) + with self.init_timer.time("optimizer_setup"): + optimizer_cls = get_class(self.cfg["optimizer"]["name"]) + self.optimizer = optimizer_cls( + self.model.parameters(), **self.cfg["optimizer"]["kwargs"] + ) else: self.optimizer = None @@ -444,6 +466,13 @@ def __init__( "No weights path provided. Starting from scratch (default policy init)" ) + # Stop total init timing + self.init_timer.stop("total_init") + + def get_init_timer(self) -> Timer: + """Return init timing for controller aggregation.""" + return self.init_timer + # Refer to nemo impl. Below is original comment. 
# based on https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/utils.py#L113 @staticmethod diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 9342c5c138..69762e6dc9 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Measure module import time (import time first for measurement) +import time + +G_MODULE_IMPORT_START_TIME = time.perf_counter() + import contextlib import gc import warnings @@ -76,6 +81,10 @@ from nemo_rl.utils.checkpoint import CheckpointingConfig from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_producer +from nemo_rl.utils.timer import Timer + +# Calculate module import duration after all imports +G_MODULE_IMPORT_DURATION = time.perf_counter() - G_MODULE_IMPORT_START_TIME def dtensor_params_generator( @@ -217,6 +226,13 @@ def __init__( **kwargs: Any, ): """Initialize the DTensorPolicyWorkerV2.""" + self.init_timer = Timer() + self.init_timer.start("total_init") + + # Record module import time + if "G_MODULE_IMPORT_DURATION" in globals(): + self.init_timer._timers["module_import"] = [G_MODULE_IMPORT_DURATION] + # Apply TE patch until TE is upgraded to 2.10.0 apply_transformer_engine_patch() # Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered' @@ -244,10 +260,11 @@ def __init__( ) # Set up distributed environment (returns FSDP2Manager) - distributed_manager = setup_distributed( - config=config, - runtime_config=runtime_config, - ) + with self.init_timer.time("setup_distributed_nccl"): + distributed_manager = setup_distributed( + config=config, + runtime_config=runtime_config, + ) # Set instance attributes from distributed manager (tuple unpacking for mesh attributes) self.rank = torch.distributed.get_rank() self.device_mesh = distributed_manager.device_mesh @@ -272,17 +289,18 @@ def __init__( ) # Set up model and optimizer - model_and_optimizer_state = setup_model_and_optimizer( - config=config, - tokenizer=tokenizer, - runtime_config=runtime_config, - distributed_manager=distributed_manager, - checkpoint_manager=self.checkpoint_manager, - is_vlm=self.is_vlm, - init_optimizer=init_optimizer, - weights_path=weights_path, - optimizer_path=optimizer_path, - ) + with self.init_timer.time("model_and_optimizer_setup"): + model_and_optimizer_state = setup_model_and_optimizer( + config=config, + tokenizer=tokenizer, + runtime_config=runtime_config, + distributed_manager=distributed_manager, + checkpoint_manager=self.checkpoint_manager, + is_vlm=self.is_vlm, + init_optimizer=init_optimizer, + weights_path=weights_path, + optimizer_path=optimizer_path, + ) # Set instance attributes from model and optimizer state (tuple unpacking) ( @@ -302,7 +320,10 @@ def __init__( # Initialize reference model if requested self.reference_model_state_dict = None if init_reference_model: - self.reference_model_state_dict = setup_reference_model_state(self.model) + with self.init_timer.time("reference_model_setup"): + self.reference_model_state_dict = setup_reference_model_state( + self.model + ) # Set instance attributes from runtime config (tuple unpacking) ( @@ -320,6 +341,13 @@ def __init__( _runtime_is_reward_model, # Duplicate, already set as _is_reward_model ) = 
runtime_config + # Stop total init timing + self.init_timer.stop("total_init") + + def get_init_timer(self) -> Timer: + """Return init timing for controller aggregation.""" + return self.init_timer + def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: if "generation" in self.cfg and self.cfg["generation"] is not None: logits.div_(self.cfg["generation"]["temperature"]) diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 37f1d8fefe..7a468e188c 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -11,10 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# Measure module import time (import time first for measurement) +import time + +G_MODULE_IMPORT_START_TIME = time.perf_counter() + import gc import os import re -import time import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext @@ -102,6 +107,10 @@ from nemo_rl.models.policy.workers.patches import apply_transformer_engine_patch from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_producer +from nemo_rl.utils.timer import Timer + +# Calculate module import duration after all imports +G_MODULE_IMPORT_DURATION = time.perf_counter() - G_MODULE_IMPORT_START_TIME TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) @@ -186,6 +195,13 @@ def __init__( **kwargs: Any, ): """Initialize the MegatronPolicyWorker.""" + self.init_timer = Timer() + self.init_timer.start("total_init") + + # Record module import time + if "G_MODULE_IMPORT_DURATION" in globals(): + self.init_timer._timers["module_import"] = [G_MODULE_IMPORT_DURATION] + # Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files apply_transformer_engine_patch() @@ -195,16 +211,18 @@ def __init__( self.rank = get_rank_safe() # Step 1: Setup distributed - setup_distributed() + with self.init_timer.time("setup_distributed_nccl"): + setup_distributed() # Step 2: Validate and setup model paths hf_model_name, pretrained_path, pt_checkpoint_exists = validate_model_paths( config ) # Handle model import if needed - handle_model_import( - config, hf_model_name, pretrained_path, pt_checkpoint_exists - ) + with self.init_timer.time("hf_model_import"): + handle_model_import( + config, hf_model_name, pretrained_path, pt_checkpoint_exists + ) # Store tokenizer self.tokenizer = tokenizer @@ -241,9 +259,10 @@ def __init__( self.megatron_cfg.validate() # Step 4: Setup Megatron model and components - model_and_optimizer_state = setup_model_and_optimizer( - config, self.megatron_cfg, init_optimizer - ) + with self.init_timer.time("model_and_optimizer_setup"): + model_and_optimizer_state = setup_model_and_optimizer( + config, self.megatron_cfg, init_optimizer, init_timer=self.init_timer + ) self.mcore_state = model_and_optimizer_state.state self.model = model_and_optimizer_state.model @@ -258,26 +277,28 @@ def __init__( # Step 5: Setup reference model if needed if init_reference_model: - self.model = self.move_model(self.model, "cpu") - self.reference_state_dict = setup_reference_model_state( - config, self.megatron_cfg, pretrained_path - ) - self.model = self.move_model(self.model, "cuda") + with 
self.init_timer.time("reference_model_setup"): + self.model = self.move_model(self.model, "cpu") + self.reference_state_dict = setup_reference_model_state( + config, self.megatron_cfg, pretrained_path + ) + self.model = self.move_model(self.model, "cuda") # Step 6: Finalize setup - ( - self.megatron_tokenizer, - self.megatron_bridge, - self.should_disable_forward_pre_hook, - self.dp_size, - ) = finalize_megatron_setup( - config, - self.megatron_cfg, - hf_model_name, - worker_sharding_annotations, - self.model, - self.optimizer, - ) + with self.init_timer.time("finalize_setup"): + ( + self.megatron_tokenizer, + self.megatron_bridge, + self.should_disable_forward_pre_hook, + self.dp_size, + ) = finalize_megatron_setup( + config, + self.megatron_cfg, + hf_model_name, + worker_sharding_annotations, + self.model, + self.optimizer, + ) # vars used for refit ## will be initialized in prepare_refit_info @@ -292,6 +313,13 @@ def __init__( ## used for streaming update inference engine weights self._held_gather_buffer = None + # Stop total init timing + self.init_timer.stop("total_init") + + def get_init_timer(self) -> Timer: + """Return init timing for controller aggregation.""" + return self.init_timer + def enable_forward_pre_hook(self): assert isinstance(self.model, DistributedDataParallel) self.model.enable_forward_pre_hook() diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index f8e9ad0c6f..413a54481e 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -87,6 +87,7 @@ class LoggerConfig(TypedDict): monitor_gpus: bool gpu_monitoring: GPUMonitoringConfig num_val_samples_to_print: NotRequired[int] + collect_worker_init_timing: NotRequired[bool] class LoggerInterface(ABC): diff --git a/nemo_rl/utils/timer.py b/nemo_rl/utils/timer.py index 5366d3f339..dccf6857b0 100644 --- a/nemo_rl/utils/timer.py +++ b/nemo_rl/utils/timer.py @@ -11,10 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import sys import time from contextlib import contextmanager -from typing import Callable, Generator, Optional, Sequence, Union +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Sequence, Union + +if TYPE_CHECKING: + from nemo_rl.distributed.worker_groups import RayWorkerGroup import numpy as np @@ -232,6 +237,45 @@ def get_timing_metrics( return results + @staticmethod + def aggregate_max( + timers: list["Timer"], + reduction_op: str = "sum", + ) -> dict[str, float]: + """Aggregate multiple timers by taking the maximum value for each label. + + Args: + timers: List of Timer objects to aggregate + reduction_op: Reduction operation to apply to each timer's measurements before aggregation. 
+ Valid options are: "mean", "median", "min", "max", "std", "sum", "count" + + Returns: + A dictionary mapping labels to the maximum value across all timers for that label + + Raises: + ValueError: If an invalid reduction operation is provided + """ + if not timers: + return {} + + # Collect all unique labels across all timers + all_labels = set() + for timer in timers: + all_labels.update(timer._timers.keys()) + + # Aggregate by taking max for each label + aggregated: dict[str, float] = {} + for label in all_labels: + max_value = float("-inf") + for timer in timers: + if label in timer._timers: + # Apply reduction to this timer's measurements for this label + reduced_value = timer.reduce(label, reduction_op) + max_value = max(max_value, reduced_value) + aggregated[label] = max_value + + return aggregated + def reset(self, label: Optional[str] = None) -> None: """Reset timings for the specified label or all labels. @@ -248,6 +292,31 @@ def reset(self, label: Optional[str] = None) -> None: self._start_times = {} +def save_worker_init_timing( + worker_groups: dict[str, "RayWorkerGroup"], + output_path: Union[str, Path], +) -> None: + """Collect and save initialization timing from multiple worker groups. + + Args: + worker_groups: Dict mapping prefix to worker_group. + output_path: Path to save the JSON file. + """ + timings: dict[str, float] = {} + metadata: dict[str, Any] = {"timestamp": time.time()} + + for prefix, worker_group in worker_groups.items(): + for k, v in worker_group.collect_init_timing().items(): + timings[f"{prefix}/{k}"] = v + metadata[f"num_{prefix}_workers"] = len(worker_group.workers) + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + json.dump({"timings": timings, "metadata": metadata}, f, indent=2) + print(f"✅ Saved worker init timing to {output_path}") + + def convert_to_seconds(time_string: str) -> int: """Converts a time string in the format 'DD:HH:MM:SS' to total seconds. diff --git a/tests/functional/grpo_megatron_worker_timing.sh b/tests/functional/grpo_megatron_worker_timing.sh new file mode 100755 index 0000000000..d0486b7491 --- /dev/null +++ b/tests/functional/grpo_megatron_worker_timing.sh @@ -0,0 +1,100 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) 
+# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +# Using Qwen2.5-0.5B instead of Qwen3-0.6B because the latter is not supported by Megatron yet +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo.py \ + --config $PROJECT_ROOT/examples/configs/grpo_math_1B_megatron.yaml \ + policy.model_name=Qwen/Qwen2.5-0.5B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.logprob_batch_size=4 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=1 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.collect_worker_init_timing=true \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Find the timing file in the exp_* subdirectory (log_dir gets exp_XXX appended by run_grpo.py) +TIMING_FILE=$(find $LOG_DIR -name "worker_init_timing.json" -type f 2>/dev/null | head -1) + +# Check that worker_init_timing.json was created +if [ -z "$TIMING_FILE" ] || [ ! -f "$TIMING_FILE" ]; then + echo "ERROR: Worker init timing file not found in $LOG_DIR (searched exp_* subdirs)" + exit 1 +fi + +# Verify the JSON file has expected structure and timing metrics +uv run python -c " +import json +import sys + +with open('$TIMING_FILE') as f: + data = json.load(f) + +# Check top-level structure +assert 'timings' in data, 'Missing timings key' +assert 'metadata' in data, 'Missing metadata key' +assert 'num_policy_workers' in data['metadata'], 'Missing num_policy_workers in metadata' +assert 'num_vllm_workers' in data['metadata'], 'Missing num_vllm_workers in metadata' +assert len(data['timings']) > 0, 'No timing data found' + +# Check for expected Megatron policy worker timing labels (prefixed with 'policy/') +expected_policy_labels = [ + 'policy/module_import', + 'policy/total_init', + 'policy/setup_distributed_nccl', + 'policy/model_and_optimizer_setup', +] + +# Check for expected vLLM generation worker timing labels (prefixed with 'vllm/') +expected_vllm_labels = [ + 'vllm/module_import', + 'vllm/total_init', + 'vllm/create_engine', +] + +expected_labels = expected_policy_labels + expected_vllm_labels +missing_labels = [label for label in expected_labels if label not in data['timings']] +if missing_labels: + print(f'ERROR: Missing expected timing labels: {missing_labels}', file=sys.stderr) + print(f'Available labels: {list(data[\"timings\"].keys())}', file=sys.stderr) + sys.exit(1) + +# Validate that timing values are reasonable (positive and less than 1000s) +for label, value in data['timings'].items(): + assert isinstance(value, (int, float)), f'Timing value for {label} is not a number: {value}' + assert value >= 0, f'Timing value for {label} is negative: {value}' + assert value < 1000, f'Timing value for {label} is unreasonably large (>1000s): {value}' + +print('✅ Worker init timing file validated successfully') +print(f' - Number of timing labels: {len(data[\"timings\"])}') +print(f' - Number of policy workers: {data[\"metadata\"][\"num_policy_workers\"]}') +print(f' - Number of vLLM workers: 
{data[\"metadata\"][\"num_vllm_workers\"]}') +print(' - Timing breakdown:') +for label, value in sorted(data['timings'].items()): + print(f' • {label}: {value:.4f}s') +" diff --git a/tests/functional/grpo_worker_timing.sh b/tests/functional/grpo_worker_timing.sh new file mode 100755 index 0000000000..167d58a947 --- /dev/null +++ b/tests/functional/grpo_worker_timing.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo.py \ + policy.model_name=Qwen/Qwen3-0.6B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=1 \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.tensorboard_enabled=false \ + logger.collect_worker_init_timing=true \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Find the timing file in the exp_* subdirectory (log_dir gets exp_XXX appended by run_grpo.py) +TIMING_FILE=$(find $LOG_DIR -name "worker_init_timing.json" -type f 2>/dev/null | head -1) + +# Check that worker_init_timing.json was created +if [ -z "$TIMING_FILE" ] || [ ! -f "$TIMING_FILE" ]; then + echo "ERROR: Worker init timing file not found in $LOG_DIR (searched exp_* subdirs)" + exit 1 +fi + +# Verify the JSON file has expected structure and timing metrics +uv run python -c " +import json +import sys + +with open('$TIMING_FILE') as f: + data = json.load(f) + +# Check top-level structure +assert 'timings' in data, 'Missing timings key' +assert 'metadata' in data, 'Missing metadata key' +assert 'num_policy_workers' in data['metadata'], 'Missing num_policy_workers in metadata' +assert len(data['timings']) > 0, 'No timing data found' + +# Check for at least some expected timing labels (prefixed with 'policy/' or 'vllm/') +common_labels = ['policy/total_init', 'vllm/total_init'] +has_common_label = any(label in data['timings'] for label in common_labels) +if not has_common_label: + print(f'WARNING: No common timing labels found. 
Available labels: {list(data[\"timings\"].keys())}', file=sys.stderr) + +# Validate that timing values are reasonable (positive and less than 1000s) +for label, value in data['timings'].items(): + assert isinstance(value, (int, float)), f'Timing value for {label} is not a number: {value}' + assert value >= 0, f'Timing value for {label} is negative: {value}' + assert value < 1000, f'Timing value for {label} is unreasonably large (>1000s): {value}' + +print('✅ Worker init timing file validated successfully') +print(f' - Number of timing labels: {len(data[\"timings\"])}') +print(f' - Number of policy workers: {data[\"metadata\"][\"num_policy_workers\"]}') +print(' - Timing breakdown:') +for label, value in sorted(data['timings'].items()): + print(f' • {label}: {value:.4f}s') +" diff --git a/tests/unit/utils/test_timer.py b/tests/unit/utils/test_timer.py index 041193b777..d973df55a6 100644 --- a/tests/unit/utils/test_timer.py +++ b/tests/unit/utils/test_timer.py @@ -233,3 +233,81 @@ def test_iteration_tracking(self): checker.mark_iteration() assert len(checker.iteration_times) == 1 assert checker.iteration_times[0] > 0 + + +class TestTimerExtensions: + """Test suite for aggregate_max method.""" + + def test_aggregate_max_basic(self): + """Test basic aggregate_max functionality.""" + # Create multiple timers with different measurements + timer1 = Timer() + timer1._timers["init"] = [1.0, 2.0] # sum = 3.0 + timer1._timers["load"] = [5.0] # sum = 5.0 + + timer2 = Timer() + timer2._timers["init"] = [3.0, 4.0] # sum = 7.0 + timer2._timers["load"] = [2.0] # sum = 2.0 + + timer3 = Timer() + timer3._timers["init"] = [1.5, 1.5] # sum = 3.0 + timer3._timers["process"] = [10.0] # sum = 10.0 + + # Aggregate using max + result = Timer.aggregate_max([timer1, timer2, timer3], reduction_op="sum") + + # Verify max values are selected for each label + assert result["init"] == 7.0 # max of [3.0, 7.0, 3.0] + assert result["load"] == 5.0 # max of [5.0, 2.0] + assert result["process"] == 10.0 # only in timer3 + + def test_aggregate_max_empty_list(self): + """Test aggregate_max with empty timer list.""" + result = Timer.aggregate_max([]) + assert result == {} + + def test_aggregate_max_single_timer(self): + """Test aggregate_max with a single timer.""" + timer = Timer() + timer._timers["operation"] = [1.0, 2.0, 3.0] + + result = Timer.aggregate_max([timer], reduction_op="mean") + assert result["operation"] == 2.0 # mean of [1, 2, 3] + + def test_aggregate_max_different_reduction_ops(self): + """Test aggregate_max with different reduction operations.""" + timer1 = Timer() + timer1._timers["op"] = [1.0, 2.0, 3.0] # mean=2.0, max=3.0, min=1.0 + + timer2 = Timer() + timer2._timers["op"] = [4.0, 5.0, 6.0] # mean=5.0, max=6.0, min=4.0 + + # Test with mean reduction + result_mean = Timer.aggregate_max([timer1, timer2], reduction_op="mean") + assert result_mean["op"] == 5.0 # max of [2.0, 5.0] + + # Test with max reduction + result_max = Timer.aggregate_max([timer1, timer2], reduction_op="max") + assert result_max["op"] == 6.0 # max of [3.0, 6.0] + + # Test with min reduction + result_min = Timer.aggregate_max([timer1, timer2], reduction_op="min") + assert result_min["op"] == 4.0 # max of [1.0, 4.0] + + def test_aggregate_max_disjoint_labels(self): + """Test aggregate_max when timers have completely different labels.""" + timer1 = Timer() + timer1._timers["operation_a"] = [1.0] + + timer2 = Timer() + timer2._timers["operation_b"] = [2.0] + + timer3 = Timer() + timer3._timers["operation_c"] = [3.0] + + result = 
Timer.aggregate_max([timer1, timer2, timer3], reduction_op="sum") + + # All labels should be present with their respective values + assert result["operation_a"] == 1.0 + assert result["operation_b"] == 2.0 + assert result["operation_c"] == 3.0
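
Illustrative note (not part of the patch): a minimal sketch of inspecting the worker_init_timing.json produced by save_worker_init_timing when logger.collect_worker_init_timing is enabled. The file layout ("timings" keyed by "<group>/<label>" plus a "metadata" dict with a timestamp and per-group worker counts) follows the diff above; the concrete path is an example, since run_grpo.py appends an exp_* subdirectory to logger.log_dir.

import json
from pathlib import Path

# Example location; adjust to your experiment's exp_* subdirectory under logger.log_dir.
timing_file = Path("logs/exp_001/worker_init_timing.json")

data = json.loads(timing_file.read_text())

# "timings" holds per-stage durations in seconds, aggregated as the max across workers,
# keyed by "<worker group>/<label>", e.g. "policy/model_and_optimizer_setup" or "vllm/create_engine".
for label, seconds in sorted(data["timings"].items(), key=lambda kv: -kv[1]):
    print(f"{label:45s} {seconds:8.3f}s")

# "metadata" records the collection timestamp and the number of workers per group.
print("policy workers:", data["metadata"].get("num_policy_workers"))
print("vllm workers:", data["metadata"].get("num_vllm_workers"))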