1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -322,6 +322,7 @@ logger:
mlflow_enabled: false # Disable MLflow logging
swanlab_enabled: false # Disable SwanLab logging
monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard
collect_worker_init_timing: false # If true, collects and saves worker initialization timing to JSON
wandb:
project: "grpo-dev"
name: "grpo-dev-logger"
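When this flag is enabled, the timing data ends up next to the other logs as `worker_init_timing.json` (see the `grpo.py` change below). A minimal sketch of reading it back; the nesting by worker group ("policy", "vllm") is an assumption about the layout produced by `save_worker_init_timing`, which is not shown in this diff:

```python
import json
from pathlib import Path

# Hypothetical reader for the timing file written when
# collect_worker_init_timing is enabled. The per-group nesting and the
# log_dir below are assumptions; adjust to the actual layout and config.
timing_path = Path("logs/grpo-dev") / "worker_init_timing.json"
with timing_path.open() as f:
    timing = json.load(f)

for group, stages in timing.items():
    for label, seconds in stages.items():
        print(f"{group}/{label}: {seconds:.2f}s")
```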
12 changes: 11 additions & 1 deletion nemo_rl/algorithms/grpo.py
@@ -74,7 +74,7 @@
)
from nemo_rl.utils.memory_tracker import MemoryTracker
from nemo_rl.utils.nsys import maybe_gpu_profile_step
from nemo_rl.utils.timer import TimeoutChecker, Timer
from nemo_rl.utils.timer import TimeoutChecker, Timer, save_worker_init_timing
from nemo_rl.utils.venvs import create_local_venv_on_each_node

# ===============================================================================
@@ -655,6 +655,16 @@ def initialize_generation_with_policy(
ray.get(futures_train + futures_inference)
worker_init_timing_metrics["collective_init_time_s"] = time.perf_counter() - t0

# Collect worker initialization timing if enabled
if master_config["logger"].get("collect_worker_init_timing", False):
worker_groups = {"policy": policy.worker_group}
if policy_generation is not None:
worker_groups["vllm"] = policy_generation.worker_group
save_worker_init_timing(
worker_groups,
Path(master_config["logger"]["log_dir"]) / "worker_init_timing.json",
)

# prepare refit info
state_dict_info = policy.prepare_refit_info()
if policy_generation is not None:
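`save_worker_init_timing` is imported from `nemo_rl.utils.timer`, but its body is not part of this diff. A plausible sketch, assuming it simply calls the `collect_init_timing()` method added below in `worker_groups.py` on each group and writes the result to the given path:

```python
import json
from pathlib import Path

def save_worker_init_timing(worker_groups: dict, output_path: Path) -> None:
    # Sketch only: the real implementation lives in nemo_rl/utils/timer.py
    # and may differ. Collect per-group init timing and dump it as one JSON file.
    timing = {
        name: group.collect_init_timing() for name, group in worker_groups.items()
    }
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w") as f:
        json.dump(timing, f, indent=2)
```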
28 changes: 28 additions & 0 deletions nemo_rl/distributed/worker_groups.py
@@ -29,6 +29,7 @@
)
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.distributed.worker_group_utils import recursive_merge_options
from nemo_rl.utils.timer import Timer
from nemo_rl.utils.venvs import (
create_local_venv_on_each_node,
)
@@ -1029,3 +1030,30 @@ def shutdown(
self._worker_metadata = []

return success

def collect_init_timing(self) -> dict[str, float]:
"""Collect and aggregate initialization timing from all workers.

Returns:
dict[str, float]: Dictionary mapping timing labels to aggregated max values.
Returns an empty dict if there are no workers or none of them expose timing.
"""
if not self._workers:
return {}

# Collect timing from all workers
timing_futures = []
for worker in self._workers:
if hasattr(worker, "get_init_timer"):
timing_futures.append(worker.get_init_timer.remote())

if not timing_futures:
return {}

# Get all timers
timers = ray.get(timing_futures)

# Aggregate using max across workers, sum within each worker
aggregated = Timer.aggregate_max(timers, reduction_op="sum")

return aggregated
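Usage sketch: the controller can call this on any worker group whose workers expose `get_init_timer()` (the vLLM worker in this PR does); the result is a flat mapping from timing label to seconds. The values shown are purely illustrative:

```python
# Sketch: aggregate init timing for an existing RayWorkerGroup instance
# (worker_group is assumed to already be constructed elsewhere).
timing = worker_group.collect_init_timing()
# e.g. {"module_import": 4.1, "create_engine": 92.3, "total_init": 110.7}
for label, seconds in sorted(timing.items(), key=lambda kv: -kv[1]):
    print(f"{label:>20}: {seconds:.1f}s")
```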
27 changes: 26 additions & 1 deletion nemo_rl/models/generation/vllm/vllm_worker.py
@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Measure module import time (import `time` first so the measurement spans the remaining imports)
import time

G_MODULE_IMPORT_START_TIME = time.perf_counter()

import copy
import gc
import os
@@ -35,6 +40,10 @@
from nemo_rl.models.huggingface.common import ModelFlag
from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled
from nemo_rl.utils.nsys import wrap_with_nvtx_name
from nemo_rl.utils.timer import Timer

# Calculate module import duration after all imports
G_MODULE_IMPORT_DURATION = time.perf_counter() - G_MODULE_IMPORT_START_TIME


# Use a base class to share some functions to avoid code duplication.
@@ -142,6 +151,14 @@ def __init__(
self.fraction_of_gpus = fraction_of_gpus
self.is_model_owner = bundle_indices is not None

# Initialize timer for tracking initialization stages
self.init_timer = Timer()
self.init_timer.start("total_init")

# Record module import time
if "G_MODULE_IMPORT_DURATION" in globals():
self.init_timer._timers["module_import"] = [G_MODULE_IMPORT_DURATION]

# Store the Python executable being used by this worker
self.py_executable = sys.executable

Expand All @@ -151,6 +168,7 @@ def __init__(
self.tokenizer = None
self.rank = 0
self.world_size = 1
self.init_timer.stop("total_init")
return

# In Ray+vLLM setup, each worker process considers itself rank 0
@@ -420,19 +438,26 @@ def _patch_vllm_vit_flash_attn_backend():
**vllm_kwargs,
)

self._create_engine(llm_kwargs)
with self.init_timer.time("create_engine"):
self._create_engine(llm_kwargs)

# will be initialized in post_init
# used in update_weights_from_ipc_handles
self.vllm_device_ids = None

self.init_timer.stop("total_init")

def llm(self):
return self.llm

def is_alive(self):
"""Check if the worker is alive."""
return True

def get_init_timer(self) -> Timer:
"""Return init timing for controller aggregation."""
return self.init_timer

def _merge_stop_strings(self, batch_stop_strings):
stop_set: set[str] = set()

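The import-time measurement above follows a simple pattern: import `time` first, stamp a module-level start time, and take the delta once the remaining (potentially slow) imports have run. A self-contained illustration of the same pattern, independent of the vLLM worker:

```python
# slow_module.py -- illustration of the module-import timing pattern.
import time

_IMPORT_START = time.perf_counter()

# Heavy imports would go here (e.g. large ML frameworks); these are stand-ins.
import json
import pathlib

MODULE_IMPORT_DURATION = time.perf_counter() - _IMPORT_START


class Worker:
    def __init__(self):
        # Record the import cost alongside other init stages.
        self.init_timing = {"module_import": MODULE_IMPORT_DURATION}
```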
68 changes: 40 additions & 28 deletions nemo_rl/models/megatron/setup.py
@@ -15,6 +15,7 @@
import os
import time
import warnings
from contextlib import nullcontext
from typing import Any, Optional, TypeVar

import torch
@@ -73,6 +74,7 @@
configure_dynamo_cache,
get_megatron_checkpoint_dir,
)
from nemo_rl.utils.timer import Timer

TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)

@@ -624,17 +626,23 @@ def setup_model_and_optimizer(
load_optimizer: bool = True,
get_embedding_ranks=None, # TODO @sahilj: What is this?
get_position_embedding_ranks=None,
init_timer: Optional[Timer] = None,
):
def timer_context(label: str):
"""Helper to conditionally use timer context."""
return init_timer.time(label) if init_timer is not None else nullcontext()

state = GlobalState()
state.cfg = megatron_cfg
# TODO: Freeze state.cfg

megatron_cfg.dist.external_gpu_device_mapping = True
initialize_megatron(
cfg=megatron_cfg,
get_embedding_ranks=get_embedding_ranks,
get_position_embedding_ranks=get_position_embedding_ranks,
)
with timer_context("initialize_megatron"):
initialize_megatron(
cfg=megatron_cfg,
get_embedding_ranks=get_embedding_ranks,
get_position_embedding_ranks=get_position_embedding_ranks,
)

if megatron_cfg.ft and megatron_cfg.ft.enable_ft_package:
fault_tolerance.setup(megatron_cfg, state)
@@ -731,22 +739,24 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
pre_wrap_hook.extend([composed_peft_hook])

# Model, optimizer, and learning rate.
model = get_model(
megatron_cfg.model,
megatron_cfg.ddp,
use_torch_fsdp2=megatron_cfg.dist.use_torch_fsdp2,
overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step,
data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init,
pre_wrap_hook=pre_wrap_hook,
mixed_precision_wrapper=mixed_precision_wrapper,
)
if load_optimizer:
optimizer, scheduler = setup_optimizer(
optimizer_config=megatron_cfg.optimizer,
scheduler_config=megatron_cfg.scheduler,
model=model,
use_gloo_process_groups=megatron_cfg.dist.use_gloo_process_groups,
with timer_context("model_init"):
model = get_model(
megatron_cfg.model,
megatron_cfg.ddp,
use_torch_fsdp2=megatron_cfg.dist.use_torch_fsdp2,
overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step,
data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init,
pre_wrap_hook=pre_wrap_hook,
mixed_precision_wrapper=mixed_precision_wrapper,
)
if load_optimizer:
with timer_context("optimizer_init"):
optimizer, scheduler = setup_optimizer(
optimizer_config=megatron_cfg.optimizer,
scheduler_config=megatron_cfg.scheduler,
model=model,
use_gloo_process_groups=megatron_cfg.dist.use_gloo_process_groups,
)
else:
optimizer = None
scheduler = None
@@ -774,14 +784,16 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:

# Load checkpoint if applicable
if should_load_checkpoint:
load_checkpoint(
state,
model,
optimizer,
scheduler,
checkpointing_context=checkpointing_context,
skip_load_to_model_and_opt=HAVE_FSDP2 and megatron_cfg.dist.use_torch_fsdp2,
)
with timer_context("load_checkpoint"):
load_checkpoint(
state,
model,
optimizer,
scheduler,
checkpointing_context=checkpointing_context,
skip_load_to_model_and_opt=HAVE_FSDP2
and megatron_cfg.dist.use_torch_fsdp2,
)
print("Checkpoint loaded")
torch.distributed.barrier()

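The `timer_context` helper keeps the timing optional: when no `init_timer` is passed, each `with timer_context(...)` block degrades to a no-op `nullcontext()`. A standalone sketch of the same pattern; the toy `SimpleTimer` stands in for `nemo_rl.utils.timer.Timer`, whose `.time(label)` is assumed to be a context manager, as its use in this diff implies:

```python
import time
from contextlib import contextmanager, nullcontext
from typing import Optional


class SimpleTimer:
    """Toy stand-in for nemo_rl.utils.timer.Timer."""

    def __init__(self):
        self.durations: dict[str, float] = {}

    @contextmanager
    def time(self, label: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.durations[label] = time.perf_counter() - start


def setup(init_timer: Optional[SimpleTimer] = None):
    def timer_context(label: str):
        # Time the block if a timer was provided, otherwise do nothing.
        return init_timer.time(label) if init_timer is not None else nullcontext()

    with timer_context("model_init"):
        time.sleep(0.01)  # placeholder for expensive setup work
```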