feat: add worker initialization timing collection #1873

yashaswikarnati wants to merge 10 commits into main
Conversation
Add config flag and collection code to aggregate and save worker init timing to JSON
📝 Walkthrough

This pull request adds initialization timing collection infrastructure throughout the worker initialization pipeline. It introduces Timer utility methods for persistence and aggregation, instruments worker initialization in vLLM and Megatron components, enables collection across distributed worker groups, and integrates optional logging of aggregated timings in the GRPO algorithm.
Sequence Diagram(s)

sequenceDiagram
participant GRPO as GRPO Algorithm
participant RayWG as RayWorkerGroup
participant Worker as MegatronPolicyWorker
participant Timer as Timer
participant Logger as Logger
participant FileSystem as FileSystem
GRPO->>RayWG: collect_init_timing()
RayWG->>Worker: get_init_timing.remote()
Note over Worker: Returns cached<br/>init_timer
Worker-->>RayWG: init_timer
RayWG->>Timer: aggregate_max(timers)
Timer-->>RayWG: aggregated timings dict
RayWG-->>GRPO: aggregated timings
alt logger.collect_worker_init_timing enabled
GRPO->>Timer: save_to_json(filepath, metadata)
Timer->>FileSystem: create directories
Timer->>FileSystem: write JSON file
FileSystem-->>Timer: success
GRPO->>Logger: print confirmation
end
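
Read top to bottom, the diagram maps onto a controller-side flow roughly like the sketch below. The method names come from this PR; the config lookup, output path, and metadata are illustrative assumptions, not the exact diff.

# Sketch of the controller-side flow (names from this review; path/metadata are illustrative)
timings = policy.worker_group.collect_init_timing()  # dict: label -> seconds (max across workers)

if timings and master_config["logger"].get("collect_worker_init_timing", False):
    aggregated_timer = Timer()
    aggregated_timer._timers = {k: [v] for k, v in timings.items()}  # same pattern as grpo.py:673-674
    aggregated_timer.save_to_json(
        "logs/worker_init_timing.json",  # assumed output path
        metadata={"num_workers": len(policy.worker_group.workers)},
    )
    print("Saved worker init timing to logs/worker_init_timing.json")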
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ❌ 2 failed checks (2 warnings) | ✅ 2 passed checks
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/utils/logger.py (1)
77-90: ⚠️ Potential issue | 🟠 Major

Document the collect_worker_init_timing config key and add a YAML default.

Line 90 introduces a new LoggerConfig key without documenting its purpose, valid values, or recommended default. Additionally, the exemplar YAML files under examples/configs/ do not include this key. Per coding guidelines, document the key's purpose and type (currently unclear; it appears to control worker initialization timing collection), specify the recommended default (suggest false based on current usage), and add it to exemplar YAMLs like examples/configs/grpo_math_1B.yaml.
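
For illustration, a sketch of what the documented default could look like in examples/configs/grpo_math_1B.yaml; the exact placement under the logger section is an assumption:

logger:
  # Collect and save aggregated worker initialization timings to JSON (default: false)
  collect_worker_init_timing: false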
🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/grpo.py`:
- Around line 658-663: The code assumes every worker in
policy.worker_group.workers implements get_init_timing; change the collection to
first filter workers using hasattr(worker, "get_init_timing") and only call
get_init_timing.remote() on those that have it (e.g., build the list
comprehension from workers where hasattr is true), then call ray.get on that
list and assign to worker_timers; also handle the case where the filtered list
is empty by setting worker_timers to an empty list to avoid passing an
empty/invalid remote call list.
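
A minimal Python sketch of the filtered collection described above; variable names follow the prompt, and error handling is elided:

# Only query workers that actually expose get_init_timing
timing_futures = [
    worker.get_init_timing.remote()
    for worker in policy.worker_group.workers
    if hasattr(worker, "get_init_timing")
]
worker_timers = ray.get(timing_futures) if timing_futures else []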
In `@nemo_rl/distributed/worker_groups.py`:
- Around line 1041-1065: The current broad except in collect_init_timing hides
Ray and typing failures; narrow it to specific errors (e.g.,
ray.exceptions.RayError and relevant subclasses like RayActorError/RayTaskError
plus TypeError/ValueError) around the ray.get/remote and Timer.aggregate_max
calls, log the caught exception (using self._logger.exception(...) if a logger
is available or logging.exception(...)) and then return {} as before; reference
collect_init_timing, worker.get_init_timing.remote(), ray.get(...), and
Timer.aggregate_max(...) when making the change.
In `@nemo_rl/models/generation/vllm/vllm_worker.py`:
- Around line 441-447: get_init_timing currently returns a dict from
self.init_timer.get_timing_metrics(reduction_op="sum") but upstream aggregation
expects Timer objects; change get_init_timing to return the Timer instance
(self.init_timer) instead of a dict so Timer.aggregate_max and related
aggregation in worker_groups.py/grpo.py work correctly. Locate the
get_init_timing method in vllm_worker.py and return self.init_timer (same
pattern as in megatron_policy_worker.py) so callers receive a Timer object
compatible with Timer.aggregate_max.
In `@nemo_rl/models/policy/workers/megatron_policy_worker.py`:
- Around line 15-19: Rename the module-level timing globals to use the G_ prefix
and upper snake_case (e.g., change _module_import_start_time to
G_MODULE_IMPORT_START_TIME) and apply the same convention to the other timing
globals referenced around lines 112-114 and 203-205; update every use of these
names within the module to the new identifiers to avoid breakage and ensure they
remain module-scoped globals.
In `@nemo_rl/utils/timer.py`:
- Around line 237-271: The JSON dump fails because get_timing_metrics returns
NumPy scalar types (e.g., np.float64) which json.dump cannot serialize; in
save_to_json convert all timing metric values to native Python types before
writing: locate save_to_json and after timing_metrics =
self.get_timing_metrics(...) walk the timing_metrics dict (and any nested
structures) and coerce NumPy floats to float(), NumPy ints/counts to int(), and
NumPy arrays/ndarrays to lists (or pick the appropriate scalar conversion) so
the resulting output dict contains only built-in types before calling json.dump.
🧹 Nitpick comments (1)
nemo_rl/algorithms/grpo.py (1)
673-675: Consider adding a public API to Timer for constructing from aggregated data.

Directly assigning to aggregated_timer._timers at line 674 bypasses the Timer class's public interface. While this pattern is used elsewhere in the codebase (e.g., megatron_policy_worker.py), a public factory method would improve encapsulation. Consider adding a classmethod to Timer:

@classmethod
def from_aggregated(cls, timing_dict: dict) -> "Timer":
    """Create a Timer from pre-aggregated timing data."""
    timer = cls()
    timer._timers = {k: [v] for k, v in timing_dict.items()}
    return timer

Then replace lines 673-674 with:

aggregated_timer = Timer.from_aggregated(max_timing)
nemo_rl/distributed/worker_groups.py (Outdated)
try:
    # Check if workers support get_init_timing
    if not self._workers:
        return {}

    # Collect timing from all workers
    timing_futures = []
    for worker in self._workers:
        if hasattr(worker, "get_init_timing"):
            timing_futures.append(worker.get_init_timing.remote())

    if not timing_futures:
        return {}

    # Get all timers
    timers = ray.get(timing_futures)

    # Aggregate using max across workers, sum within each worker
    aggregated = Timer.aggregate_max(timers, reduction_op="sum")

    return aggregated

except Exception:
    # Return empty dict on any error
    return {}
Narrow the exception handling in collect_init_timing.
Catching bare Exception hides Ray failures and makes debugging hard. Please limit to likely Ray/typing errors and log the exception.
🔧 Suggested adjustment
-    except Exception:
-        # Return empty dict on any error
-        return {}
+    except (ray.exceptions.RayError, AttributeError, TypeError) as exc:
+        print(f"Error collecting init timing: {exc}")
+        return {}

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
try:
    # Check if workers support get_init_timing
    if not self._workers:
        return {}

    # Collect timing from all workers
    timing_futures = []
    for worker in self._workers:
        if hasattr(worker, "get_init_timing"):
            timing_futures.append(worker.get_init_timing.remote())

    if not timing_futures:
        return {}

    # Get all timers
    timers = ray.get(timing_futures)

    # Aggregate using max across workers, sum within each worker
    aggregated = Timer.aggregate_max(timers, reduction_op="sum")

    return aggregated

except (ray.exceptions.RayError, AttributeError, TypeError) as exc:
    print(f"Error collecting init timing: {exc}")
    return {}
🤖 Prompt for AI Agents
In `@nemo_rl/distributed/worker_groups.py` around lines 1041 - 1065, The current
broad except in collect_init_timing hides Ray and typing failures; narrow it to
specific errors (e.g., ray.exceptions.RayError and relevant subclasses like
RayActorError/RayTaskError plus TypeError/ValueError) around the ray.get/remote
and Timer.aggregate_max calls, log the caught exception (using
self._logger.exception(...) if a logger is available or logging.exception(...))
and then return {} as before; reference collect_init_timing,
worker.get_init_timing.remote(), ray.get(...), and Timer.aggregate_max(...) when
making the change.
def get_init_timing(self) -> dict[str, float]:
    """Get initialization timing metrics.

    Returns:
        Dictionary mapping timing labels to elapsed times in seconds
    """
    return self.init_timer.get_timing_metrics(reduction_op="sum")
🧩 Analysis chain
🏁 Script executed:
# First, check the current state of the method
cat -n nemo_rl/models/generation/vllm/vllm_worker.py | sed -n '441,447p'

Repository: NVIDIA-NeMo/RL
Length of output: 376
🏁 Script executed:
# Search for all usages of get_init_timing
rg -n "get_init_timing" -C 3Repository: NVIDIA-NeMo/RL
Length of output: 2943
🏁 Script executed:
# Search for collect_init_timing and aggregation patterns
rg -n "collect_init_timing|aggregate_max" -C 3Repository: NVIDIA-NeMo/RL
Length of output: 6943
🏁 Script executed:
# Check the Timer class definition and its methods
rg -n "class Timer" -C 5Repository: NVIDIA-NeMo/RL
Length of output: 646
🏁 Script executed:
# Search for how init_timer is aggregated across workers
rg -n "init_timer" -C 2Repository: NVIDIA-NeMo/RL
Length of output: 6504
Return Timer object instead of dict from get_init_timing.
Line 447 returns a dict from get_timing_metrics(), but the aggregation pipeline in worker_groups.py and grpo.py expects Timer objects. The call to Timer.aggregate_max(timers, ...) will fail when receiving dicts instead of Timer instances. The megatron_policy_worker.py correctly returns the Timer object directly; vllm_worker.py should do the same.
Proposed fix
-    def get_init_timing(self) -> dict[str, float]:
-        """Get initialization timing metrics.
-
-        Returns:
-            Dictionary mapping timing labels to elapsed times in seconds
-        """
-        return self.init_timer.get_timing_metrics(reduction_op="sum")
+    def get_init_timing(self) -> Timer:
+        """Return init timing for controller aggregation."""
+        return self.init_timer

🤖 Prompt for AI Agents
In `@nemo_rl/models/generation/vllm/vllm_worker.py` around lines 441 - 447,
get_init_timing currently returns a dict from
self.init_timer.get_timing_metrics(reduction_op="sum") but upstream aggregation
expects Timer objects; change get_init_timing to return the Timer instance
(self.init_timer) instead of a dict so Timer.aggregate_max and related
aggregation in worker_groups.py/grpo.py work correctly. Locate the
get_init_timing method in vllm_worker.py and return self.init_timer (same
pattern as in megatron_policy_worker.py) so callers receive a Timer object
compatible with Timer.aggregate_max.
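
To make the aggregation contract concrete, here is a small sketch using Timer's internal _timers layout (label → list of seconds) as shown elsewhere in this review; the labels and values are made up:

t1, t2 = Timer(), Timer()
t1._timers = {"model_load": [12.0, 0.5]}  # two measurements on worker 1
t2._timers = {"model_load": [10.0]}       # one measurement on worker 2

# Sum within each worker first (12.5 vs 10.0), then take the max across workers.
aggregated = Timer.aggregate_max([t1, t2], reduction_op="sum")
# aggregated == {"model_load": 12.5}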
# Measure module import time (import time first for measurement)
import time

_module_import_start_time = time.perf_counter()
Rename module-level timing globals to G_* upper snake_case.
Module-level globals should use the G_ prefix and upper snake_case.
🔤 Suggested rename
-_module_import_start_time = time.perf_counter()
+G_MODULE_IMPORT_START_TIME = time.perf_counter()
...
-_module_import_duration = time.perf_counter() - _module_import_start_time
+G_MODULE_IMPORT_DURATION = time.perf_counter() - G_MODULE_IMPORT_START_TIME
...
-        if "_module_import_duration" in globals():
-            self.init_timer._timers["module_import"] = [_module_import_duration]
+        if "G_MODULE_IMPORT_DURATION" in globals():
+            self.init_timer._timers["module_import"] = [G_MODULE_IMPORT_DURATION]
+ self.init_timer._timers["module_import"] = [G_MODULE_IMPORT_DURATION]Also applies to: 112-114, 203-205
🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/megatron_policy_worker.py` around lines 15 -
19, Rename the module-level timing globals to use the G_ prefix and upper
snake_case (e.g., change _module_import_start_time to
G_MODULE_IMPORT_START_TIME) and apply the same convention to the other timing
globals referenced around lines 112-114 and 203-205; update every use of these
names within the module to the new identifiers to avoid breakage and ensure they
remain module-scoped globals.
nemo_rl/utils/timer.py (Outdated)
def save_to_json(
    self,
    filepath: Union[str, Path],
    reduction_op: str = "sum",
    metadata: Optional[dict] = None,
) -> None:
    """Save timing measurements to a JSON file.

    Args:
        filepath: Path where the JSON file will be saved
        reduction_op: Reduction operation to apply to all timing measurements.
            Valid options are: "mean", "median", "min", "max", "std", "sum", "count"
        metadata: Optional dictionary of metadata to include in the JSON file

    Raises:
        ValueError: If an invalid reduction operation is provided
    """
    filepath = Path(filepath)

    # Get timing metrics with the specified reduction
    timing_metrics = self.get_timing_metrics(reduction_op=reduction_op)

    # Build the output dictionary
    output = {
        "timings": timing_metrics,
        "reduction_op": reduction_op,
    }

    if metadata is not None:
        output["metadata"] = metadata

    # Write to JSON file
    filepath.parent.mkdir(parents=True, exist_ok=True)
    with open(filepath, "w") as f:
        json.dump(output, f, indent=2)
Convert NumPy scalars before JSON serialization.
get_timing_metrics() uses NumPy reductions, which return NumPy scalar values. json.dump can't reliably serialize these and may raise TypeError at runtime. Coerce to built-in types before dumping.
🛠️ Proposed fix
     # Get timing metrics with the specified reduction
     timing_metrics = self.get_timing_metrics(reduction_op=reduction_op)

+    def _to_jsonable(value: object) -> object:
+        if isinstance(value, np.generic):
+            return value.item()
+        if isinstance(value, list):
+            return [_to_jsonable(v) for v in value]
+        return value
+
+    timing_metrics = {k: _to_jsonable(v) for k, v in timing_metrics.items()}
+
     # Build the output dictionary
     output = {
         "timings": timing_metrics,
         "reduction_op": reduction_op,
     }

🤖 Prompt for AI Agents
In `@nemo_rl/utils/timer.py` around lines 237 - 271, The JSON dump fails because
get_timing_metrics returns NumPy scalar types (e.g., np.float64) which json.dump
cannot serialize; in save_to_json convert all timing metric values to native
Python types before writing: locate save_to_json and after timing_metrics =
self.get_timing_metrics(...) walk the timing_metrics dict (and any nested
structures) and coerce NumPy floats to float(), NumPy ints/counts to int(), and
NumPy arrays/ndarrays to lists (or pick the appropriate scalar conversion) so
the resulting output dict contains only built-in types before calling json.dump.
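
A quick usage sketch of the save_to_json API shown in this diff; the path and metadata values are illustrative:

timer = Timer()
# ... after recording some measurements ...
timer.save_to_json(
    "logs/init_timing.json",            # parent directories are created automatically
    reduction_op="sum",                 # one of: "mean", "median", "min", "max", "std", "sum", "count"
    metadata={"run_id": "grpo-debug"},  # optional; stored under the "metadata" key
)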
ℹ️ File Consistency Check

Check based on commit: b678d6d (PR #1873)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR. Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
Force-pushed from b678d6d to d45b2a5
ℹ️ File Consistency Check (commit d45b2a5): ✅ DTensor Policy Worker Synchronization Check, same notice as above.
ℹ️ File Consistency Check (commit 59c2193): ✅ DTensor Policy Worker Synchronization Check, same notice as above.
Signed-off-by: ykarnati <[email protected]>
ℹ️ File Consistency Check (commit 5233a32): ✅ DTensor Policy Worker Synchronization Check, same notice as above.
Signed-off-by: ykarnati <[email protected]>
ℹ️ File Consistency Check (commit a50defa): ✅ DTensor Policy Worker Synchronization Check, same notice as above.
Add config flag and collection code to aggregate and save worker init timing to JSON
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Tests