From 438f2850b3c5e410added59d54608688ac40d3d6 Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Wed, 4 Feb 2026 13:05:56 +0000 Subject: [PATCH 1/2] feat: Add perfetto tracing for async GRPO training Signed-off-by: Georg Stefan Schmid --- nemo_rl/algorithms/async_utils.py | 164 ++++++++++----- nemo_rl/algorithms/grpo.py | 61 ++++-- nemo_rl/utils/trace.py | 335 ++++++++++++++++++++++++++++++ tests/unit/utils/test_trace.py | 157 ++++++++++++++ 4 files changed, 649 insertions(+), 68 deletions(-) create mode 100644 nemo_rl/utils/trace.py create mode 100644 tests/unit/utils/test_trace.py diff --git a/nemo_rl/algorithms/async_utils.py b/nemo_rl/algorithms/async_utils.py index c1ce9ab762..79e125fcf7 100644 --- a/nemo_rl/algorithms/async_utils.py +++ b/nemo_rl/algorithms/async_utils.py @@ -28,6 +28,7 @@ run_async_multi_turn_rollout, ) from nemo_rl.models.generation.interfaces import GenerationInterface +from nemo_rl.utils.trace import define_collect_trace, new_tracer, Tracer TokenizerType = PreTrainedTokenizerBase @@ -52,6 +53,12 @@ def __init__(self, max_size: int): self.last_target_weight_already_generated = -1 self._lock = _threading.Lock() + self._tracer = new_tracer("replay_buffer") + + @define_collect_trace + def collect_trace(self): + return self._tracer.get_events() + def push_with_wait_signal( self, trajectory: dict[str, Any], @@ -65,7 +72,14 @@ def push_with_wait_signal( weight_version: version of the model weights used for generation target_weight_version: version of the model weights this trajectory is intended for training """ - with self._lock: + span = self._tracer.span( + "push_with_wait_signal", + metadata={ + "weight_version": weight_version, + "target_weight_version": target_weight_version, + } + ) + with span, self._lock: if len(self.trajectories) >= self.max_size: return "full" @@ -115,7 +129,7 @@ def sample( Returns: Dictionary with 'trajectories' and 'avg_trajectory_age' keys, or None if insufficient data """ - with self._lock: + with self._tracer.span("sample"), self._lock: if not self.trajectories: return None @@ -291,6 +305,18 @@ def __init__( # Track which target weights are currently being generated (globally) self._generating_targets: set[int] = set() + self._tracer = new_tracer("trajectory_collector") + self._loop_tracer = new_tracer("trajectory_collector_loop") + self._worker_tracers = [] + + @define_collect_trace + def collect_trace(self): + events = self._tracer.get_events() + events.extend(self._loop_tracer.get_events()) + for worker_tracer in self._worker_tracers: + events.extend(worker_tracer.get_events()) + return events + def _calculate_target_weights(self, generation_weight_version: int) -> list[int]: """Calculate target weight versions for given generation weight version. @@ -351,6 +377,13 @@ def set_weight_version(self, version: int) -> None: print(f"πŸ”„ Updated weight version to {version}, resuming collection") else: print(f"πŸ”„ Updated weight version to {version}") + self._tracer.add_instant_event( + "set_weight_version", + metadata={ + "version": version, + "was_paused": was_paused, + } + ) def _should_pause_for_generation_limits(self) -> bool: """Check if collection should be paused due to generation limits.""" @@ -454,6 +487,14 @@ def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None: generation_weight_version = self.current_weight_version num_generations = self.master_config["grpo"]["num_generations_per_prompt"] num_prompts = batch.size + self._loop_tracer.start_span( + "process_batch", + metadata={ + "generation_weight_version": generation_weight_version, + "num_generations": num_generations, + "num_prompts": num_prompts, + }, + ) # Get the next target weight that needs generation target_weight = self._get_next_target_for_generation( @@ -491,6 +532,9 @@ def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None: repeated_batch = single_prompt_batch.repeat_interleave(num_generations) self._inflight_sema.acquire() + worker_tracer = new_tracer(f"trajectory_collector_worker_prompt{prompt_idx}") + if self._tracer.enabled: + self._worker_tracers.append(worker_tracer) worker = _threading.Thread( target=self._run_prompt_group_worker, args=( @@ -498,12 +542,21 @@ def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None: generation_weight_version, target_weight, prompt_idx, + worker_tracer, ), daemon=True, ) with self._threads_lock: self._inflight_threads.add(worker) worker.start() + self._loop_tracer.add_instant_event( + "launch run_prompt_group_worker", + metadata={ + "generation_weight_version": generation_weight_version, + "target_weight_version": target_weight, + "prompt_idx": prompt_idx, + }, + ) self._cleanup_finished_threads() @@ -513,6 +566,9 @@ def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None: traceback.print_exc() + finally: + self._loop_tracer.end_span("process_batch") + def get_weight_version(self) -> int: return self.current_weight_version @@ -577,50 +633,52 @@ def prepare_for_refit(self) -> None: def resume_after_refit(self) -> None: """Resume new generation starts after refit is complete.""" - print("πŸ”„ Resuming generation starts after refit") - - # Invalidate&recompute vLLM caches after the in-flight weight updates if - # recompute_kv_cache_after_weight_updates is True (AREAL-style implementation). - # Otherwise, keep using the stale KV caches (Magistral-style implementation). - async_cfg = self.master_config.get("grpo", {}).get("async_grpo", {}) - if async_cfg.get("in_flight_weight_updates", False) and async_cfg.get( - "recompute_kv_cache_after_weight_updates", False - ): - try: - print("πŸ”„ Invalidating vLLM prefix/KV caches after weight update") - invalidated = self.policy_generation.invalidate_kv_cache() - if invalidated: - print("βœ… Invalidated vLLM prefix/KV caches after weight update") - else: - print( - "⚠️ vLLM cache invalidation reported partial/unsuccessful on some workers" - ) - except Exception as e: - print(f"⚠️ Failed to invalidate vLLM caches: {e}") + with self._tracer.span("resume_after_refit"): + print("πŸ”„ Resuming generation starts after refit") + + # Invalidate&recompute vLLM caches after the in-flight weight updates if + # recompute_kv_cache_after_weight_updates is True (AREAL-style implementation). + # Otherwise, keep using the stale KV caches (Magistral-style implementation). + async_cfg = self.master_config.get("grpo", {}).get("async_grpo", {}) + if async_cfg.get("in_flight_weight_updates", False) and async_cfg.get( + "recompute_kv_cache_after_weight_updates", False + ): + try: + print("πŸ”„ Invalidating vLLM prefix/KV caches after weight update") + invalidated = self.policy_generation.invalidate_kv_cache() + if invalidated: + print("βœ… Invalidated vLLM prefix/KV caches after weight update") + else: + print( + "⚠️ vLLM cache invalidation reported partial/unsuccessful on some workers" + ) + except Exception as e: + print(f"⚠️ Failed to invalidate vLLM caches: {e}") - self._refit_pause_cleared.set() + self._refit_pause_cleared.set() def wait_for_pending_generations(self) -> None: """Wait for all in-flight generation threads to complete.""" - start_time = time.time() + with self._tracer.span("wait_for_pending_generations"): + start_time = time.time() - while True: - with self._threads_lock: - finished = {t for t in self._inflight_threads if not t.is_alive()} - for t in finished: - self._inflight_threads.remove(t) + while True: + with self._threads_lock: + finished = {t for t in self._inflight_threads if not t.is_alive()} + for t in finished: + self._inflight_threads.remove(t) - pending_count = len(self._inflight_threads) + pending_count = len(self._inflight_threads) - if pending_count == 0: - print("βœ… All generation threads completed") - break + if pending_count == 0: + print("βœ… All generation threads completed") + break - elapsed = time.time() - start_time - print( - f"⏳ Waiting for {pending_count} pending generation threads... ({elapsed:.1f}s elapsed)" - ) - time.sleep(0.5) + elapsed = time.time() - start_time + print( + f"⏳ Waiting for {pending_count} pending generation threads... ({elapsed:.1f}s elapsed)" + ) + time.sleep(0.5) def get_dataloader_state(self) -> dict: """Get the current dataloader state for checkpointing.""" @@ -640,19 +698,29 @@ def _run_prompt_group_worker( generation_weight_version: int, target_weight_version: int, prompt_idx: int, + worker_tracer: Tracer, ) -> None: + worker_tracer.start_span( + "run_prompt_group_worker", + metadata={ + "generation_weight_version": generation_weight_version, + "target_weight_version": target_weight_version, + "prompt_idx": prompt_idx, + } + ) try: # Run rollout for this prompt group # Async engine supports concurrent generation; avoid locking - final_batch, rollout_metrics = run_async_multi_turn_rollout( - policy_generation=self.policy_generation, - input_batch=repeated_batch, - tokenizer=self.tokenizer, - task_to_env=self.task_to_env, - max_seq_len=self.master_config["policy"]["max_total_sequence_length"], - max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"], - greedy=False, - ) + with worker_tracer.span("run_async_multi_turn_rollout"): + final_batch, rollout_metrics = run_async_multi_turn_rollout( + policy_generation=self.policy_generation, + input_batch=repeated_batch, + tokenizer=self.tokenizer, + task_to_env=self.task_to_env, + max_seq_len=self.master_config["policy"]["max_total_sequence_length"], + max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"], + greedy=False, + ) # Move to CPU and push to buffer (avoid blocking on GC/push) final_batch_cpu = final_batch.to("cpu") @@ -728,3 +796,5 @@ def _run_prompt_group_worker( import traceback traceback.print_exc() + + worker_tracer.end_span("run_prompt_group_worker") diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index f223fa091c..29ff65e0ea 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -17,6 +17,7 @@ import warnings from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext +from functools import partial from pathlib import Path from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast @@ -75,6 +76,8 @@ from nemo_rl.utils.memory_tracker import MemoryTracker from nemo_rl.utils.nsys import maybe_gpu_profile_step from nemo_rl.utils.timer import TimeoutChecker, Timer +from nemo_rl.utils.trace import new_tracer, save_trace +from nemo_rl.utils.trace import trace_and_time as _trace_and_time from nemo_rl.utils.venvs import create_local_venv_on_each_node # =============================================================================== @@ -1497,7 +1500,6 @@ def grpo_train( ) train_data.update(extra_multimodal_data) train_data.to("cpu") - metrics_logging_data["content"] = flat_messages["content"] memory_tracker.snapshot_start_of_stage("Computing logprobs", dir()) @@ -2111,7 +2113,9 @@ def async_grpo_train( # Import async utilities only when needed from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer + tracer = new_tracer() timer = Timer() + trace_and_time = partial(_trace_and_time, tracer, timer) timeout = TimeoutChecker( timeout=master_config["checkpointing"]["checkpoint_must_save_by"], fit_last_save_time=True, @@ -2235,7 +2239,8 @@ def async_grpo_train( if NEED_REFIT and POLICY_GENERATION_STALE: print("πŸ”„ Refitting policy generation with actual model weights...") try: - refit_policy_generation(policy, policy_generation, colocated_inference) + with tracer.span("refit"): + refit_policy_generation(policy, policy_generation, colocated_inference) print("βœ… Policy generation refit completed successfully") POLICY_GENERATION_STALE = False except Exception as e: @@ -2322,10 +2327,10 @@ def async_grpo_train( if policy != policy_generation: maybe_gpu_profile_step(policy_generation, step + 1) - with timer.time("total_step_time"): + with trace_and_time("step", time_label="total_step_time", metadata={"step": step}): # Sample trajectories from replay buffer print("πŸ“¦ Sampling from replay buffer...") - with timer.time("exposed_generation"): + with trace_and_time("sample", time_label="exposed_generation"): buffer_size_current = ray.get(replay_buffer.size.remote()) print( f"πŸ“Š Step coordination: training_step={step}, max_age={max_trajectory_age_steps}, buffer_size={buffer_size_current}" @@ -2413,7 +2418,7 @@ def async_grpo_train( print(f"Got trajectory batch (size: {repeated_batch.size})") print("β–Ά Processing rewards...") - with timer.time("reward_calculation"): + with trace_and_time("reward_calculation"): prompt_only_message_logs = [] for message_log in repeated_batch["message_log"]: prompt_only_log = [] @@ -2465,7 +2470,7 @@ def async_grpo_train( ) # Prepare training data (same as sync version) - with timer.time("data_processing"): + with trace_and_time("data_processing"): # Add loss mask and advantages to each message for i, message_log in enumerate(repeated_batch["message_log"]): for j, message in enumerate(message_log): @@ -2508,42 +2513,47 @@ def async_grpo_train( train_data.to("cpu") # Training phase (same as sync version) + tracer.start_span("training") print("β–Ά Preparing for logprob inference...") - with timer.time("logprob_inference_prep"): + with trace_and_time("logprob_inference_prep"): policy.prepare_for_lp_inference() print("β–Ά Computing logprobs...") - with timer.time("policy_and_reference_logprobs"): - fprop_logprobs = policy.get_logprobs( - train_data, - timer=timer, - )["logprobs"] - reference_logprobs = policy.get_reference_policy_logprobs( - train_data, - timer=timer, - )["reference_logprobs"] + with trace_and_time("policy_and_reference_logprobs"): + with tracer.span("policy_logprobs"): + fprop_logprobs = policy.get_logprobs( + train_data, + timer=timer, + )["logprobs"] + + with tracer.span("reference_policy_logprobs"): + reference_logprobs = policy.get_reference_policy_logprobs( + train_data, + timer=timer, + )["reference_logprobs"] train_data["prev_logprobs"] = fprop_logprobs train_data["reference_policy_logprobs"] = reference_logprobs print("β–Ά Preparing for training...") - with timer.time("training_prep"): + with trace_and_time("training_prep"): policy.prepare_for_training() POLICY_GENERATION_STALE = True print("β–Ά Training policy...") - with timer.time("policy_training"): + with trace_and_time("policy_training"): train_results = policy.train( train_data, loss_fn, timer=timer, ) + tracer.end_span("training") print("πŸ”„ Synchronizing policy weights to trajectory collector…") generation_logger_metrics = None if NEED_REFIT: # Measure pending-generation wait as exposed_generation time print("πŸ”„ Coordinating with trajectory collector before refit...") - with timer.time("exposed_generation"): + with trace_and_time("prepare_for_refit", time_label="exposed_generation"): ray.get(trajectory_collector.prepare_for_refit.remote()) # Collect generation logger metrics for performance reporting @@ -2555,7 +2565,7 @@ def async_grpo_train( # Only the actual refit/weight transfer should be counted as weight_sync print("πŸ”„ Performing policy generation refit...") - with timer.time("weight_sync"): + with trace_and_time("weight_sync"): refit_policy_generation( policy, policy_generation, colocated_inference ) @@ -2574,6 +2584,7 @@ def async_grpo_train( val_metrics, validation_timings = None, None is_last_step = step + 1 == master_config["grpo"]["max_num_steps"] + tracer.start_span("validation") if val_period > 0 and (step + 1) % val_period == 0: # Pause trajectory collection during validation to reduce memory pressure trajectory_collector.pause.remote() @@ -2608,6 +2619,8 @@ def async_grpo_train( # Resume trajectory collection after validation trajectory_collector.resume.remote() + tracer.end_span("validation") + # Get flat advantages and token mask for masked metrics computation flat_advantages = flat_messages["advantages"] flat_token_mask = flat_messages["token_loss_mask"] @@ -2720,7 +2733,7 @@ def async_grpo_train( metric_name ] - with timer.time("checkpointing"): + with trace_and_time("checkpointing"): print(f"Saving checkpoint for step {step + 1}...") checkpoint_path = checkpointer.init_tmp_checkpoint( step + 1, grpo_save_state, master_config @@ -2843,6 +2856,12 @@ def async_grpo_train( traceback.print_exc() finally: + try: + ray.get(trajectory_collector.pause.remote()) + save_trace(tracer.get_events(), (replay_buffer, trajectory_collector)) + except Exception as e: + print(f"Error saving tracer events: {e}") + # Clean up print("πŸ›‘ Stopping trajectory collection...") try: diff --git a/nemo_rl/utils/trace.py b/nemo_rl/utils/trace.py new file mode 100644 index 0000000000..3601b63afa --- /dev/null +++ b/nemo_rl/utils/trace.py @@ -0,0 +1,335 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tracing infrastructure for NemoRL training with Perfetto/Chrome Trace Format support. + +This module provides lightweight tracing for GRPO and other RL training algorithms, +generating Chrome Trace Event Format JSON files that can be visualized in Perfetto UI +(https://ui.perfetto.dev) or chrome://tracing. + +Usage: + # Enable tracing via environment variable + export NEMORL_TRACE_ENABLED=1 + export NEMORL_TRACE_FILE=/path/to/trace.json + + # In your training code + from nemo_rl.utils.trace import new_tracer, save_trace + + tracer = new_tracer("grpo_driver") + + for step in range(42): + with tracer.span("step", metadata={"step": step}): + with tracer.span("generation"): + # generation code + pass + with tracer.span("training"): + # training code + pass + + save_trace(tracer.get_events(), actors=()) +""" +import json +import os +import threading +import time +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Generator, Optional + +try: + import ray + RAY_AVAILABLE = True +except ImportError: + RAY_AVAILABLE = False + +from nemo_rl.utils.timer import Timer + +Event = dict[str, Any] + + +class Tracer: + """Lightweight tracer for NemoRL training that outputs Chrome Trace Format. + + This tracer accumulates timing events during training and exports them to + a JSON file compatible with Perfetto UI and chrome://tracing. + """ + + def __init__( + self, + enabled: bool = False, + name: str = "", + ): + """Initialize the tracer. + + Args: + enabled: Whether tracing is enabled. If False, all operations are no-ops. + """ + self._enabled = enabled + self._events: list[Event] = [] + self._events_lock = threading.Lock() + self._span_stack: list[tuple[str, float, dict[str, Any]]] = [] + + self._pid = os.getpid() + self._name = name + # We only intitialize tid upon the first span. This allows us to create the + # tracer on a different thread from the one it'll be used on. + self._tid = None + + def _ensure_tid(self): + tid = threading.current_thread().native_id + if self._tid is None: + self._tid = tid + else: + assert tid == self._tid, f"Tracer used on different threads: {tid=} <> {self._tid=}" + + def start_span(self, name: str, metadata: Optional[dict[str, Any]] = None) -> None: + """Start a traced span. Make sure to call end_span(name) when done. + + Args: + name: Name of the span (e.g., "generation", "training", "step") + metadata: Optional metadata to attach to the span (e.g., step numbers) + + Example: + tracer.start_span("step", metadata={"step": step}) + # ... + tracer.end_span("step") + """ + if not self._enabled: + return + self._ensure_tid() + + start_ts = time.monotonic() + if metadata is None: + metadata = {} + metadata["tracer_name"] = self._name + self._span_stack.append((name, start_ts, metadata)) + + begin_event = { + "name": name, + "ph": "B", # Begin phase + "ts": int(start_ts * 1_000_000), # microseconds + "pid": self._pid, + "tid": self._tid, + "args": metadata, + } + + with self._events_lock: + self._events.append(begin_event) + + def end_span(self, name: str) -> None: + """End the most recently started span. + + Args: + name: Optional name to verify we're ending the right span. + + Raises: + ValueError: If name doesn't match current span or no span is active + """ + if not self._enabled: + return + + if not self._span_stack: + raise ValueError(f"No active span to end (expected {name=})") + + span_name, _span_start, _span_metadata = self._span_stack.pop() + + if name != span_name: + raise ValueError(f"Span name mismatch: expected '{name}', got '{span_name}'") + + end_event = { + "name": span_name, + "ph": "E", # End phase + "ts": int(time.monotonic() * 1_000_000), # microseconds + "pid": self._pid, + "tid": self._tid, + } + + with self._events_lock: + self._events.append(end_event) + + @contextmanager + def span( + self, + name: str, + metadata: Optional[dict[str, Any]] = None, + ) -> Generator[None, None, None]: + """Create a traced span (timing block) with optional metadata. + + Args: + name: Name of the span (e.g., "generation", "training", "step") + metadata: Optional metadata to attach to the span (e.g., step numbers) + + Example: + with tracer.span("step", metadata={"step": step}): + # ... + """ + if not self._enabled: + yield + return + self._ensure_tid() + + self.start_span(name, metadata) + try: + yield + finally: + self.end_span(name) + + def add_instant_event( + self, + name: str, + metadata: Optional[dict[str, Any]] = None, + scope: str = "t", + ) -> None: + """Add an instant event (point-in-time marker) to the trace. + + Args: + name: Name of the instant event + metadata: Optional metadata to attach + scope: Scope of the event ("t" = thread, "p" = process, "g" = global) + """ + if not self._enabled: + return + self._ensure_tid() + assert scope in ("t", "p", "g") + + if metadata is None: + metadata = {} + metadata["tracer_name"] = self._name + + event = { + "name": name, + "ph": "i", # Instant event + "ts": int(time.monotonic() * 1_000_000), + "pid": self._pid, + "tid": self._tid, + "s": scope, + "args": metadata, + } + + with self._events_lock: + self._events.append(event) + + def add_counter( + self, + name: str, + value: float, + metadata: Optional[dict[str, Any]] = None, + ) -> None: + """Add a counter event to the trace. + + Counter events are useful for tracking metrics over time (e.g., reward, + loss, batch size) and appear as line graphs in trace viewers. + + Args: + name: Name of the counter + value: Counter value + metadata: Optional additional metadata + """ + if not self._enabled: + return + self._ensure_tid() + + if metadata is None: + metadata = {} + metadata["tracer_name"] = self._name + metadata["value"] = value + + event = { + "name": name, + "ph": "C", # Counter event + "ts": int(time.monotonic() * 1_000_000), + "pid": self._pid, + "tid": self._tid, + "args": metadata, + } + + with self._events_lock: + self._events.append(event) + + def get_events(self) -> list[Event]: + """Get the accumulated trace events. + + Useful for programmatic analysis or custom export formats. + + Returns: + List of trace event dictionaries + """ + with self._events_lock: + return list(self._events) + + @property + def enabled(self) -> bool: + """Check if tracing is enabled.""" + return self._enabled + + +def tracing_enabled(): + return os.environ.get("NEMORL_TRACE_ENABLED", "0").lower() in ("1", "true", "yes") + + +def new_tracer(name: str = "") -> Tracer: + return Tracer(enabled=tracing_enabled(), name=name) + + +def define_collect_trace(get_tracer_events): + def collect_trace(self, timing: bool): + if timing: + return time.monotonic() + else: + return get_tracer_events(self) + return collect_trace + + +def save_trace(local_events: list[Event], actors: tuple[..., Any]): + if not tracing_enabled(): + return + + events = local_events + for actor in actors: + # Poor man's clock synchronization to account for actors running on different + # nodes. + ts_local = time.monotonic() + ts_actor = ray.get(actor.collect_trace.remote(timing=True)) + latency = (time.monotonic() - ts_local) / 2 + ts_delta = int((ts_actor - ts_local - latency) * 1_000_000) + + actor_events = ray.get(actor.collect_trace.remote(timing=False)) + for actor_event in actor_events: + actor_event["ts"] -= ts_delta + events.extend(actor_events) + + # Perfetto wants events to be sorted. Ensure that they are, even if we merged tracers. + events.sort(key=lambda event: event["ts"]) + + output_path = os.environ.get("NEMORL_TRACE_FILE", "nemorl_trace.json") + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + json.dump(events, f, indent=2) + + print(f"Trace saved to: {output_path}") + print("View in Perfetto UI: https://ui.perfetto.dev") + print("Or open in Chrome: chrome://tracing") + print(f"Total events: {len(events)}") + + +@contextmanager +def trace_and_time( + tracer: Tracer, + timer: Timer, + span_name: str, + time_label: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, +): + time_label = time_label or span_name + with tracer.span(span_name, metadata), timer.time(time_label): + yield diff --git a/tests/unit/utils/test_trace.py b/tests/unit/utils/test_trace.py new file mode 100644 index 0000000000..1f6c20fadb --- /dev/null +++ b/tests/unit/utils/test_trace.py @@ -0,0 +1,157 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile +import time + +import pytest + +from nemo_rl.utils.trace import new_tracer, save_trace + + +class TestNemoRLTracer: + @pytest.fixture(autouse=True) + def mock_tracing_env(self, monkeypatch): + monkeypatch.setenv("NEMORL_TRACE_ENABLED", "1") + + def test_tracer_disabled(self, monkeypatch): + """Test that tracer is disabled when not explicitly enabled.""" + monkeypatch.delenv("NEMORL_TRACE_ENABLED", raising=False) + + tracer = new_tracer() + assert not tracer.enabled + + # Operations should be no-ops + with tracer.span("test"): + pass + + assert len(tracer.get_events()) == 0 + + def test_tracer_basic(self): + """Test that tracer captures events when enabled.""" + tracer = new_tracer() + assert tracer.enabled + + with tracer.span("test_span"): + pass + + events = tracer.get_events() + assert len(events) == 2 # Begin and end events + assert events[0]["ph"] == "B" + assert events[1]["ph"] == "E" + assert events[0]["name"] == "test_span" + assert events[1]["name"] == "test_span" + + def test_nested_spans(self): + """Test that nested spans are properly tracked.""" + tracer = new_tracer() + + with tracer.span("outer"): + with tracer.span("inner"): + pass + + events = tracer.get_events() + assert len(events) == 4 # 2 begin + 2 end events + assert events[0]["name"] == "outer" + assert events[1]["name"] == "inner" + assert events[2]["name"] == "inner" + assert events[3]["name"] == "outer" + + def test_span_with_metadata(self): + """Test that metadata is properly attached to spans.""" + tracer = new_tracer(name="foo") + + with tracer.span("test", metadata={"step": 5, "batch_size": 32}): + pass + + events = tracer.get_events() + begin_event = events[0] + assert "args" in begin_event + assert begin_event["args"]["tracer_name"] == "foo" + assert begin_event["args"]["step"] == 5 + assert begin_event["args"]["batch_size"] == 32 + + def test_explicit_start_end_span(self): + """Test explicit start_span/end_span calls.""" + tracer = new_tracer() + + tracer.start_span("phase1", metadata={"id": 1}) + # ... + tracer.end_span("phase1") + + events = tracer.get_events() + assert len(events) == 2 + assert events[0]["name"] == "phase1" + assert events[0]["args"]["id"] == 1 + + def test_mismatched_end_span_raises_error(self): + """Test that ending a span with wrong name raises error.""" + tracer = new_tracer() + + tracer.start_span("span1") + with pytest.raises(ValueError, match="Span name mismatch"): + tracer.end_span("span2") + + def test_end_span_without_start_raises_error(self): + """Test that ending a non-existent span raises error.""" + tracer = new_tracer() + + with pytest.raises(ValueError, match="No active span"): + tracer.end_span("nonexistent") + + def test_instant_event(self): + """Test instant event creation.""" + tracer = new_tracer() + + tracer.add_instant_event("checkpoint_saved", metadata={"step": 100}) + + events = tracer.get_events() + assert len(events) == 1 + assert events[0]["ph"] == "i" + assert events[0]["name"] == "checkpoint_saved" + assert events[0]["args"]["step"] == 100 + + def test_counter_event(self): + """Test counter event creation.""" + tracer = new_tracer() + + tracer.add_counter("reward", 0.85, metadata={"step": 1}) + + events = tracer.get_events() + assert len(events) == 1 + assert events[0]["ph"] == "C" + assert events[0]["name"] == "reward" + assert events[0]["args"]["value"] == 0.85 + assert events[0]["args"]["step"] == 1 + + def test_save_trace_file(self, monkeypatch): + """Test saving trace to JSON file.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, "test_trace.json") + monkeypatch.setenv("NEMORL_TRACE_FILE", output_path) + tracer = new_tracer() + + with tracer.span("test"): + pass + + save_trace(tracer.get_events(), actors=()) + assert os.path.exists(output_path) + + # Verify JSON is valid and contains events + with open(output_path, "r") as f: + data = json.load(f) + assert isinstance(data, list) + assert len(data) == 2 # Begin and end events From 9cbb426ba2ec4e6c77b84b8c388d93853543393c Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Wed, 4 Feb 2026 19:52:52 +0000 Subject: [PATCH 2/2] Remove leftover code --- nemo_rl/utils/trace.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/nemo_rl/utils/trace.py b/nemo_rl/utils/trace.py index 3601b63afa..d33d2f14ed 100644 --- a/nemo_rl/utils/trace.py +++ b/nemo_rl/utils/trace.py @@ -46,11 +46,7 @@ from pathlib import Path from typing import Any, Generator, Optional -try: - import ray - RAY_AVAILABLE = True -except ImportError: - RAY_AVAILABLE = False +import ray from nemo_rl.utils.timer import Timer