feat: Add perfetto tracing for async GRPO training #1876
gspschmid wants to merge 2 commits into NVIDIA-NeMo:main
Conversation
Force-pushed from 580a4a5 to 0fa44fc
📝 Walkthrough

The changes introduce comprehensive tracing instrumentation to the async utilities and GRPO training algorithms.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 2 | ❌ 2

❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/async_utils.py`:
- Around line 308-318: The code currently always creates/accumulates worker
tracers causing memory growth; change initialization and usage of
_worker_tracers to be conditional on tracing being enabled: initialize
self._worker_tracers as an empty list only if self._tracer.enabled (otherwise
set to None or skip), and update any code that creates/appends worker tracers to
check self._tracer.enabled before creating/appending; ensure collect_trace
(method collect_trace) handles the disabled case by skipping iterating over
_worker_tracers when tracing is off and still returns events from the main
tracers.
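The suggested fix can be sketched as follows. `FakeTracer`, `CollectorSketch`, and `register_worker_tracer` are hypothetical stand-ins for the actual classes in `async_utils.py`, shown only to illustrate the conditional-accumulation pattern:

```python
class FakeTracer:
    """Minimal stand-in for the PR's tracer (hypothetical, for illustration)."""

    def __init__(self, enabled):
        self.enabled = enabled
        self._events = []

    def get_events(self):
        return list(self._events)


class CollectorSketch:
    """Sketch of the suggested fix: only accumulate worker tracers when
    tracing is enabled, so a disabled run never grows per-worker state."""

    def __init__(self, tracer):
        self._tracer = tracer
        # Only allocate the list when tracing is on.
        self._worker_tracers = [] if tracer.enabled else None

    def register_worker_tracer(self, worker_tracer):
        # Skip accumulation entirely when tracing is off.
        if self._tracer.enabled:
            self._worker_tracers.append(worker_tracer)

    def collect_trace(self):
        events = self._tracer.get_events()
        # Handle the disabled case: _worker_tracers may be None.
        if self._worker_tracers:
            for worker_tracer in self._worker_tracers:
                events.extend(worker_tracer.get_events())
        return events


# Disabled tracing: no worker tracers are retained.
off = CollectorSketch(FakeTracer(enabled=False))
off.register_worker_tracer(FakeTracer(enabled=False))
assert off._worker_tracers is None
assert off.collect_trace() == []

# Enabled tracing: worker tracers are retained and drained.
on = CollectorSketch(FakeTracer(enabled=True))
on.register_worker_tracer(FakeTracer(enabled=True))
assert len(on._worker_tracers) == 1
```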
In `@nemo_rl/algorithms/grpo.py`:
- Around line 2516-2549: Replace the manual tracer.start_span("training") /
tracer.end_span("training") pair with a context-manager form so spans are always
closed on exceptions: wrap the entire training block (the code that calls
policy.prepare_for_lp_inference(), policy.get_logprobs(...),
policy.get_reference_policy_logprobs(...), policy.prepare_for_training(), and
policy.train(...)) in a single with tracer.span("training"): block instead of
start_span/end_span; do the same for the validation block (the block that
currently uses tracer.start_span("validation")/tracer.end_span("validation")) so
both spans use with tracer.span("..."):. Ensure the code nested inside remains
unchanged and indented under the with blocks so the span is properly closed on
exception.
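Why the context-manager form matters can be shown with a toy tracer (`MiniTracer` is hypothetical; the real tracer lives in `nemo_rl/utils/trace.py`):

```python
import contextlib


class MiniTracer:
    """Toy stand-in for the PR's tracer (hypothetical, for illustration only)."""

    def __init__(self):
        self.open_spans = []  # spans started but not yet ended
        self.events = []      # completed spans

    def start_span(self, name):
        self.open_spans.append(name)

    def end_span(self, name):
        self.open_spans.remove(name)
        self.events.append(name)

    @contextlib.contextmanager
    def span(self, name):
        # end_span runs in the finally block, even if the body raises.
        self.start_span(name)
        try:
            yield
        finally:
            self.end_span(name)


# Manual pairing leaks the span when the body raises:
tracer = MiniTracer()
try:
    tracer.start_span("training")
    raise RuntimeError("simulated failure mid-step")
except RuntimeError:
    pass
assert tracer.open_spans == ["training"]  # unmatched span left behind

# The with-form closes the span despite the same exception:
tracer2 = MiniTracer()
try:
    with tracer2.span("training"):
        raise RuntimeError("simulated failure mid-step")
except RuntimeError:
    pass
assert tracer2.open_spans == []  # span properly closed
assert tracer2.events == ["training"]
```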
In `@nemo_rl/utils/trace.py`:
- Line 1: The file header in nemo_rl/utils/trace.py still shows "Copyright (c)
2025, NVIDIA CORPORATION." — update this top-of-file copyright year to 2026 so
the header reads "Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved."
and ensure the updated header appears at the very top of the file (affecting the
file that defines tracing utilities in this module).
- Around line 320-321: Remove the unnecessary f-string prefixes on the two print
statements that output Perfetto/Chrome tracing links in nemo_rl/utils/trace.py:
replace print(f"View in Perfetto UI: https://ui.perfetto.dev") and print(f"Or
open in Chrome: chrome://tracing") with plain string literals (print("View in
Perfetto UI: https://ui.perfetto.dev") and print("Or open in Chrome:
chrome://tracing")) to satisfy Ruff F541; update both occurrences and run the
linter to confirm the warning is cleared.
- Line 144: Unpack the tuple from self._span_stack.pop() using an
underscore-prefixed name for the unused value: change the current unpacking
"span_name, span_start, _span_metadata = self._span_stack.pop()" to use
"_span_start" instead of "span_start" so the unused variable follows Python
convention and suppresses the Ruff warning; ensure no other references to
span_start exist in the method before committing.
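A minimal illustration of the convention (the stack-entry values here are made up):

```python
import time

# Illustrative span-stack entry: (name, start time, metadata).
span_stack = [("training", time.perf_counter(), {"step": 1})]

# Underscore-prefix the unused bindings so Ruff's unused-variable rule passes:
span_name, _span_start, _span_metadata = span_stack.pop()
assert span_name == "training"
assert span_stack == []
```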
🧹 Nitpick comments (5)
nemo_rl/utils/trace.py (2)
49-53: Rename global `RAY_AVAILABLE` to use the `G_` prefix. This aligns the global with the required naming convention (and update any references accordingly).
🔧 Suggested fix
```diff
 try:
     import ray
-    RAY_AVAILABLE = True
+    G_RAY_AVAILABLE = True
 except ImportError:
-    RAY_AVAILABLE = False
+    G_RAY_AVAILABLE = False
```

As per coding guidelines: use upper snake_case with a `G_` prefix for global variables, e.g., `G_MY_GLOBAL`.
276-335: Add Google-style docstrings for public tracing helpers.

`tracing_enabled`, `new_tracer`, `define_collect_trace`, `save_trace`, and `trace_and_time` are public helpers but currently lack docstrings.

✍️ Example (apply similarly to the other helpers)

```diff
 def tracing_enabled():
+    """Check whether tracing is enabled via environment variables.
+
+    Returns:
+        True if NEMORL_TRACE_ENABLED is truthy, else False.
+    """
     return os.environ.get("NEMORL_TRACE_ENABLED", "0").lower() in ("1", "true", "yes")
```

As per coding guidelines: use Google style docstrings for classes and functions.
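Since the suggested docstring documents the env-var behavior, here is a runnable sketch of that check. The optional `environ` parameter is an addition for testability; the real helper reads `os.environ` directly:

```python
import os


def tracing_enabled(environ=None):
    """Check whether tracing is enabled via environment variables.

    The `environ` parameter is a testability addition; the actual helper in
    nemo_rl/utils/trace.py reads os.environ directly.
    """
    env = os.environ if environ is None else environ
    return env.get("NEMORL_TRACE_ENABLED", "0").lower() in ("1", "true", "yes")


# Truthy spellings are accepted case-insensitively; anything else is off.
assert tracing_enabled({"NEMORL_TRACE_ENABLED": "1"})
assert tracing_enabled({"NEMORL_TRACE_ENABLED": "True"})
assert not tracing_enabled({})
assert not tracing_enabled({"NEMORL_TRACE_ENABLED": "off"})
```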
nemo_rl/algorithms/async_utils.py (2)
58-60: Document the `collect_trace` actor APIs. These methods are called remotely (e.g., during trace aggregation) and should carry Google-style docstrings.

✍️ Suggested update

```diff
 @define_collect_trace
 def collect_trace(self):
+    """Collect tracer events for Perfetto export."""
     return self._tracer.get_events()
@@
 @define_collect_trace
 def collect_trace(self):
+    """Collect tracer events for Perfetto export."""
     events = self._tracer.get_events()
     events.extend(self._loop_tracer.get_events())
     for worker_tracer in self._worker_tracers:
         events.extend(worker_tracer.get_events())
     return events
```

As per coding guidelines: use Google style docstrings for classes and functions.
Also applies to: 312-318
635-655: Avoid hidden config defaults and narrow the exception type. Use explicit config values (no implicit `False` defaults) and catch only expected failures from cache invalidation.

🛠️ Suggested fix

```diff
-        async_cfg = self.master_config.get("grpo", {}).get("async_grpo", {})
-        if async_cfg.get("in_flight_weight_updates", False) and async_cfg.get(
-            "recompute_kv_cache_after_weight_updates", False
-        ):
+        async_cfg = self.master_config["grpo"]["async_grpo"]
+        if async_cfg.get("in_flight_weight_updates") and async_cfg.get(
+            "recompute_kv_cache_after_weight_updates"
+        ):
             try:
                 print("🔄 Invalidating vLLM prefix/KV caches after weight update")
                 invalidated = self.policy_generation.invalidate_kv_cache()
                 if invalidated:
                     print("✅ Invalidated vLLM prefix/KV caches after weight update")
                 else:
                     print(
                         "⚠️ vLLM cache invalidation reported partial/unsuccessful on some workers"
                     )
-            except Exception as e:
+            except RuntimeError as e:
                 print(f"⚠️ Failed to invalidate vLLM caches: {e}")
```

As per coding guidelines: YAML is the single source of truth for configuration defaults; do not set non-None defaults in code for configuration values. In try-except blocks, limit the except clause to the smallest set of errors possible.
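The rationale for dropping the implicit `False` defaults can be illustrated with a plain dict (the config contents here are invented):

```python
# Invented config for illustration; real values come from the YAML config.
cfg = {"grpo": {"async_grpo": {"in_flight_weight_updates": True}}}

# Direct indexing fails loudly if the YAML section is missing...
async_cfg = cfg["grpo"]["async_grpo"]

# ...while .get() without a hard-coded default returns None (falsy) for a
# missing key instead of silently introducing a second source of defaults:
assert async_cfg.get("in_flight_weight_updates") is True
assert async_cfg.get("recompute_kv_cache_after_weight_updates") is None

raised = False
try:
    cfg["grpo"]["missing_section"]
except KeyError:
    raised = True  # incomplete YAML surfaces immediately
assert raised
```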
nemo_rl/algorithms/grpo.py (1)
2859-2864: Narrow the exception type when saving traces. Catching `Exception` masks unexpected failures; constrain it to the expected Ray/IO errors.

🛠️ Suggested fix

```diff
-        except Exception as e:
+        except (ray.exceptions.RayError, OSError) as e:
             print(f"Error saving tracer events: {e}")
```

As per coding guidelines: in try-except blocks, limit the except clause to the smallest set of errors possible.
```diff
+tracer.start_span("training")
 print("▶ Preparing for logprob inference...")
-with timer.time("logprob_inference_prep"):
+with trace_and_time("logprob_inference_prep"):
     policy.prepare_for_lp_inference()

 print("▶ Computing logprobs...")
-with timer.time("policy_and_reference_logprobs"):
-    fprop_logprobs = policy.get_logprobs(
-        train_data,
-        timer=timer,
-    )["logprobs"]
-    reference_logprobs = policy.get_reference_policy_logprobs(
-        train_data,
-        timer=timer,
-    )["reference_logprobs"]
+with trace_and_time("policy_and_reference_logprobs"):
+    with tracer.span("policy_logprobs"):
+        fprop_logprobs = policy.get_logprobs(
+            train_data,
+            timer=timer,
+        )["logprobs"]
+
+    with tracer.span("reference_policy_logprobs"):
+        reference_logprobs = policy.get_reference_policy_logprobs(
+            train_data,
+            timer=timer,
+        )["reference_logprobs"]
 train_data["prev_logprobs"] = fprop_logprobs
 train_data["reference_policy_logprobs"] = reference_logprobs

 print("▶ Preparing for training...")
-with timer.time("training_prep"):
+with trace_and_time("training_prep"):
     policy.prepare_for_training()
     POLICY_GENERATION_STALE = True

 print("▶ Training policy...")
-with timer.time("policy_training"):
+with trace_and_time("policy_training"):
     train_results = policy.train(
         train_data,
         loss_fn,
         timer=timer,
     )
+tracer.end_span("training")
```
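The `trace_and_time` helper replacing `timer.time` above can be imagined as a context manager that both times a block and records a span. This is a hypothetical sketch, not the actual implementation in `nemo_rl/utils/trace.py`:

```python
import contextlib
import time


@contextlib.contextmanager
def trace_and_time(name, timer_log=None):
    """Hypothetical sketch: time a block and record it under `name`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Record even when the block raises, mirroring context-manager spans.
        elapsed = time.perf_counter() - start
        if timer_log is not None:
            timer_log[name] = elapsed


timings = {}
with trace_and_time("policy_training", timer_log=timings):
    time.sleep(0.01)
assert timings["policy_training"] > 0

try:
    with trace_and_time("failing_step", timer_log=timings):
        raise ValueError("boom")
except ValueError:
    pass
assert "failing_step" in timings  # recorded despite the exception
```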
🧩 Analysis chain

🏁 Scripts executed:

```shell
# First, let's check the file exists and get its size
wc -l nemo_rl/algorithms/grpo.py

# Now let's look at the code around lines 2516-2549
sed -n '2516,2549p' nemo_rl/algorithms/grpo.py

# And the code around lines 2587-2623
sed -n '2587,2623p' nemo_rl/algorithms/grpo.py

# Let's find where tracer is defined or imported
rg -n "tracer\s*=" nemo_rl/algorithms/grpo.py | head -20

# Search for where tracer's span method is called to understand its usage
rg -n "tracer\.span\(" nemo_rl/algorithms/grpo.py | head -10

# Search for tracer import
rg -n "import.*tracer|from.*tracer" nemo_rl/algorithms/grpo.py

# Find the tracer implementation
find nemo_rl -name "trace.py" -o -name "tracer.py"

# Check if nemo_rl/utils/trace.py exists and read it
wc -l nemo_rl/utils/trace.py

# Read the trace.py file
cat -n nemo_rl/utils/trace.py
```
Use context-manager spans to guarantee closure on exceptions.
The `tracer.span()` context manager ensures `end_span()` executes in a `finally` block, preventing unmatched spans if an exception bubbles up between `start_span` and `end_span`. Refactor the "training" span (lines 2516-2549) and "validation" span (lines 2587-2623) to use `with tracer.span("name"):` instead of manual `start_span`/`end_span` calls.
Signed-off-by: Georg Stefan Schmid <gschmid@nvidia.com>
Force-pushed from 0aaaf93 to 438f285
Fwiw, I noticed that we might also be able to insert additional traces in Ray's own timeline, though I'm not sure what trade-offs that would come with. From the Ray documentation (https://docs.ray.io/en/latest/ray-observability/user-guides/ray-tracing.html#tracing) it seems that the feature is deprecated and requires OpenTelemetry as an external dependency. |
Re 1: In terms of tracing output I think it would be relatively similar, i.e. a single perfetto trace that covers spans across various Ray actors. Annotation overhead (in terms of additional code) would be comparable as well: spans are introduced via context managers in either case (https://docs.ray.io/en/latest/ray-observability/user-guides/ray-tracing.html#custom-traces). I was initially wondering whether the existing Ray infrastructure might be more robust, but given the note on top of that documentation page ("Tracing is an Alpha feature and no longer under active development/being maintained. APIs are subject to change.") I am less inclined to investigate much more deeply. In any case, the motivation behind this PR is mostly to trace a handful of steps of a training job to inform which performance optimizations we should focus on.

Re 2: I agree that the fewer lines of code we touch (and indent), the better. Where we can use decorators we probably should, and indeed it might make sense to add a more general decorator.
@youngeunkwon0405 please review |
Hi @gspschmid, thanks for your contribution. This will be a very useful feature. I have a few suggestions on this PR.
Adds high-level tracing of async GRPO in the driver process, the trajectory collector, and the replay buffer. This provides a quick visual impression of the time each component of GRPO contributes and the degree to which they overlap.
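For context, Perfetto-compatible traces are typically Chrome trace-event JSON. A minimal sketch of what a file like `nemorl_trace.json` might contain (the field names follow the trace-event format; the event payloads are invented):

```python
import json

# Two complete ("ph": "X") duration events on separate pids; ts/dur are in
# microseconds. All values here are made up for illustration.
events = [
    {"name": "training", "ph": "X", "ts": 0, "dur": 1_500_000, "pid": 0, "tid": 0},
    {"name": "generation", "ph": "X", "ts": 500_000, "dur": 2_000_000, "pid": 1, "tid": 0},
]
payload = json.dumps({"traceEvents": events})

# Round-trips cleanly; this is the general shape https://ui.perfetto.dev/ accepts.
decoded = json.loads(payload)
assert decoded["traceEvents"][0]["name"] == "training"
assert decoded["traceEvents"][1]["dur"] == 2_000_000
```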
Issues
List issues that this PR closes:
(None)
Usage
The resulting `nemorl_trace.json` can be opened in the usual perfetto trace viewer, e.g. via https://ui.perfetto.dev/. Here's what a simple example looks like (timings not necessarily representative, since this was run on a toy example, grpo_math_1B-2n8g-async-1off):



Before your PR is "Ready for review"
Pre checks:
cc @guyueh1