Skip to content

Conversation

@guyueh1
Copy link
Contributor

@guyueh1 guyueh1 commented Feb 5, 2026

What does this PR do ?

In nemo_gym environment, when generation_logprobs contains NaN, retry rollout.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features

    • Improved NeMo-Gym rollout robustness: rollout operations now automatically retry when NaN values are detected in generation log probabilities. A new configurable parameter controls the maximum number of retry attempts.
  • Tests

    • Added test coverage for the rollout retry mechanism.

@guyueh1 guyueh1 requested review from a team as code owners February 5, 2026 17:01
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 5, 2026

📝 Walkthrough

Walkthrough

The changes introduce retry logic to NeMo-Gym rollout collection to handle NaN values in generation log probabilities. A new configuration field enables retrying rollouts up to a maximum count, with NaN detection in generation_logprobs triggering automatic retries until valid results are obtained or the retry limit is reached.

Changes

Cohort / File(s) Summary
Configuration Update
examples/nemo_gym/run_grpo_nemo_gym.py
Added rollout_max_retries_to_avoid_lp_nan keyword argument to NemoGymConfig initialization, sourced from configuration with a default value of 1.
Core Retry Logic
nemo_rl/environments/nemo_gym.py
Added rollout_max_retries_to_avoid_lp_nan configuration field to NemoGymConfig and reworked run_rollouts method to implement a while loop that retries rollouts when NaN values are detected in generation_logprobs, incrementing a trial counter and re-executing until valid results are obtained or max retries are reached.
Test Coverage
tests/unit/environments/test_nemo_gym.py
Introduced module-level constant NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN and invocation counter, added nemo_gym_with_patched_run_examples fixture that patches RolloutCollectionHelper.run_examples to inject NaN values and track calls, and added test_nemo_gym_rollout_max_retries_to_avoid_lp_nan test to verify retry behavior matches configuration.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 44.44% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Test Results For Major Changes ⚠️ Warning PR introduces major feature with core rollout logic changes but lacks documented test results or testing information in description. Add test execution results to PR description demonstrating all tests pass. Include regression testing documentation, particularly for convergence and numerical stability impacts.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately captures the primary change: adding retry logic when generation_logprobs contains NaN, which is implemented across all modified files.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/nemo_gym/run_grpo_nemo_gym.py (1)

252-257: ⚠️ Potential issue | 🟠 Major

Avoid code-level default for rollout_max_retries_to_avoid_lp_nan.

Using cfg.get(..., 1) introduces a hidden default in code; please read the key directly and define the default in YAML.

🔧 Suggested fix
-        rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg.get("rollout_max_retries_to_avoid_lp_nan", 1),
+        rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg["rollout_max_retries_to_avoid_lp_nan"],

As per coding guidelines, "YAML is the single source of truth for configuration defaults; do not set non-None defaults in code for configuration values" and "Access required config values directly (e.g., policy_cfg['precision']) and assume they are present; do not introduce hidden defaults in code."

nemo_rl/environments/nemo_gym.py (1)

27-31: ⚠️ Potential issue | 🟠 Major

Remove the TypedDict default and mark the field as required.

rollout_max_retries_to_avoid_lp_nan: int = 1 sets a non-None default in code, which violates the guideline that YAML is the single source of truth for configuration defaults. In TypedDict, this pattern is also inconsistent with the codebase convention of using NotRequired[int] for optional fields. Since the field is always explicitly provided at instantiation sites (never relying on the class default), it should be declared as required without a default value.

🔧 Suggested fix
 class NemoGymConfig(TypedDict):
     model_name: str
     base_urls: List[str]
     initial_global_config_dict: Dict[str, Any]
-    rollout_max_retries_to_avoid_lp_nan: int = 1
+    rollout_max_retries_to_avoid_lp_nan: int
🤖 Fix all issues with AI agents
In `@nemo_rl/environments/nemo_gym.py`:
- Around line 115-152: Validate and clearly define the semantics of max_retries
before entering the loop: ensure cfg["rollout_max_retries_to_avoid_lp_nan"] is
an int >= 1 (or raise a ValueError) so nemo_gym_num_rows is always defined; keep
current semantics as "max attempts" by replacing the while trial < max_retries
loop with a for attempt in range(max_retries): or explicitly document that
max_retries is the total number of attempts, and remove the off‑by‑one
ambiguity; reference the variables max_retries, trial, the while trial <
max_retries loop, and nemo_gym_num_rows when making the check and adjustment.

In `@tests/unit/environments/test_nemo_gym.py`:
- Around line 206-208: Add a Google-style docstring to the pytest fixture
nemo_gym_with_patched_run_examples describing its purpose, parameters (if any)
and return value; place it immediately below the def
nemo_gym_with_patched_run_examples(...) line and follow Google style sections
(Args:, Returns:) and mention that it yields a nemo_gym instance with
RolloutCollectionHelper.run_examples patched for tests so readers and Sphinx can
parse it.
- Around line 247-295: The fixture currently calls context.__enter__ and yields
env but if actor creation or setup fails the function exits before calling
context.__exit__, leaking the patch; wrap the setup and yield in a try/finally
so context.__exit__ is always called: call context.__enter__ first, then create
config and env via NemoGym.options(...).remote and perform
ray.get(env.health_check.remote()) inside the try, yield env as before, and in
the finally ensure you call env.shutdown.remote() and ray.kill(env) only if env
was created, then call context.__exit__(None, None, None) to guarantee the
patch_run_examples context is reverted even on failures.
- Around line 46-47: The mutable module-global run_examples_called should follow
the project's naming and test-safety conventions: rename it to
G_RUN_EXAMPLES_CALLED (upper snake case with G_ prefix) and ensure it's reset
before each test to avoid cross-test leakage; update all references to
run_examples_called in this file (e.g., where it is incremented or asserted) to
G_RUN_EXAMPLES_CALLED and add a pytest fixture or test setup that sets
G_RUN_EXAMPLES_CALLED = 0 before each test runs.
- Around line 222-239: The patched new_run_examples is unpacking awaitables
returned by orig_run_examples and yielding tuples, causing await task to fail in
run_rollouts; instead, for each awaitable "task" from orig_run_examples(self,
examples, head_server_config) create and yield a new async wrapper coroutine (or
future) that awaits the original task, injects the NaN into
result["response"]["output"] (preserve the has_generation_log_probs check and
raise ValueError if none), and then returns the (row, result) pair—i.e., leave
orig_run_examples and run_rollouts semantics intact by yielding an awaitable
that performs the mutation after awaiting the original "task".

Comment on lines 115 to 152
timer.start("_run_rollouts_total")
nemo_rl_rowidxs = []
nemo_rl_results = []
for task in nemo_gym_result_iterator:
with timer.time(label=f"{timer_prefix}/await_results"):
nemo_gym_row, nemo_gym_result = await task

with timer.time(label=f"{timer_prefix}/postprocess_results"):
nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
nemo_gym_result, tokenizer
)

nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
nemo_rl_results.append(nemo_rl_result)
max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
while trial < max_retries:

nemo_gym_num_rows = len(nemo_gym_examples)
nemo_gym_result_iterator = self.rch.run_examples(
examples=nemo_gym_examples, head_server_config=self.head_server_config
)

nemo_rl_rowidxs = []
nemo_rl_results = []
for task in nemo_gym_result_iterator:
with timer.time(label=f"{timer_prefix}/await_results"):
nemo_gym_row, nemo_gym_result = await task

with timer.time(label=f"{timer_prefix}/postprocess_results"):
nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
nemo_gym_result, tokenizer
)

nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
nemo_rl_results.append(nemo_rl_result)


# determine if generation_logprobs contain NaN; if not, break;
logprob_contains_nan = False
for nemo_rl_result in nemo_rl_results:
for message in nemo_rl_result["message_log"]:
if "generation_logprobs" in message and message["generation_logprobs"] is not None:
if torch.isnan(message["generation_logprobs"]).any():
logprob_contains_nan = True
break
if logprob_contains_nan:
trial += 1
print(f"Generation logprobs contain NaN; retrying... (trial {trial}/{max_retries})")
continue
else:
break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Clarify retry semantics and guard against non‑positive values.

while trial < max_retries treats the setting as max attempts; if users interpret it as “retries,” it’s off-by-one. Also, max_retries <= 0 leaves nemo_gym_num_rows undefined. Please validate and clearly define semantics.

🔧 Suggested fix (validation; keep current semantics)
-        max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
-        while trial < max_retries:
+        max_retries = self.cfg["rollout_max_retries_to_avoid_lp_nan"]
+        if max_retries < 1:
+            raise ValueError("rollout_max_retries_to_avoid_lp_nan must be >= 1")
+        trial = 0
+        while trial < max_retries:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
timer.start("_run_rollouts_total")
nemo_rl_rowidxs = []
nemo_rl_results = []
for task in nemo_gym_result_iterator:
with timer.time(label=f"{timer_prefix}/await_results"):
nemo_gym_row, nemo_gym_result = await task
with timer.time(label=f"{timer_prefix}/postprocess_results"):
nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
nemo_gym_result, tokenizer
)
nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
nemo_rl_results.append(nemo_rl_result)
max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
while trial < max_retries:
nemo_gym_num_rows = len(nemo_gym_examples)
nemo_gym_result_iterator = self.rch.run_examples(
examples=nemo_gym_examples, head_server_config=self.head_server_config
)
nemo_rl_rowidxs = []
nemo_rl_results = []
for task in nemo_gym_result_iterator:
with timer.time(label=f"{timer_prefix}/await_results"):
nemo_gym_row, nemo_gym_result = await task
with timer.time(label=f"{timer_prefix}/postprocess_results"):
nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
nemo_gym_result, tokenizer
)
nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
nemo_rl_results.append(nemo_rl_result)
# determine if generation_logprobs contain NaN; if not, break;
logprob_contains_nan = False
for nemo_rl_result in nemo_rl_results:
for message in nemo_rl_result["message_log"]:
if "generation_logprobs" in message and message["generation_logprobs"] is not None:
if torch.isnan(message["generation_logprobs"]).any():
logprob_contains_nan = True
break
if logprob_contains_nan:
trial += 1
print(f"Generation logprobs contain NaN; retrying... (trial {trial}/{max_retries})")
continue
else:
break
timer.start("_run_rollouts_total")
max_retries = self.cfg["rollout_max_retries_to_avoid_lp_nan"]
if max_retries < 1:
raise ValueError("rollout_max_retries_to_avoid_lp_nan must be >= 1")
trial = 0
while trial < max_retries:
nemo_gym_num_rows = len(nemo_gym_examples)
nemo_gym_result_iterator = self.rch.run_examples(
examples=nemo_gym_examples, head_server_config=self.head_server_config
)
nemo_rl_rowidxs = []
nemo_rl_results = []
for task in nemo_gym_result_iterator:
with timer.time(label=f"{timer_prefix}/await_results"):
nemo_gym_row, nemo_gym_result = await task
with timer.time(label=f"{timer_prefix}/postprocess_results"):
nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
nemo_gym_result, tokenizer
)
nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
nemo_rl_results.append(nemo_rl_result)
# determine if generation_logprobs contain NaN; if not, break;
logprob_contains_nan = False
for nemo_rl_result in nemo_rl_results:
for message in nemo_rl_result["message_log"]:
if "generation_logprobs" in message and message["generation_logprobs"] is not None:
if torch.isnan(message["generation_logprobs"]).any():
logprob_contains_nan = True
break
if logprob_contains_nan:
trial += 1
print(f"Generation logprobs contain NaN; retrying... (trial {trial}/{max_retries})")
continue
else:
break
🤖 Prompt for AI Agents
In `@nemo_rl/environments/nemo_gym.py` around lines 115 - 152, Validate and
clearly define the semantics of max_retries before entering the loop: ensure
cfg["rollout_max_retries_to_avoid_lp_nan"] is an int >= 1 (or raise a
ValueError) so nemo_gym_num_rows is always defined; keep current semantics as
"max attempts" by replacing the while trial < max_retries loop with a for
attempt in range(max_retries): or explicitly document that max_retries is the
total number of attempts, and remove the off‑by‑one ambiguity; reference the
variables max_retries, trial, the while trial < max_retries loop, and
nemo_gym_num_rows when making the check and adjustment.

Comment on lines 46 to 47
NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN = 3
run_examples_called = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Use G_ prefix for mutable global counter and reset per test.

run_examples_called is a mutable module global; rename to the required G_-prefixed upper snake case and reset it in the fixture to prevent cross‑test leakage.

🔧 Suggested fix
 NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN = 3
-run_examples_called = 0
+G_RUN_EXAMPLES_CALLED = 0
@@
-            global run_examples_called
-            run_examples_called += 1
+            global G_RUN_EXAMPLES_CALLED
+            G_RUN_EXAMPLES_CALLED += 1
@@
-    context = patch_run_examples()
+    global G_RUN_EXAMPLES_CALLED
+    G_RUN_EXAMPLES_CALLED = 0
+    context = patch_run_examples()
@@
-    assert run_examples_called == NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN
+    assert G_RUN_EXAMPLES_CALLED == NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN

As per coding guidelines, "Use upper snake_case with G prefix for global variables, e.g., G_MY_GLOBAL."

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN = 3
run_examples_called = 0
NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN = 3
G_RUN_EXAMPLES_CALLED = 0
🤖 Prompt for AI Agents
In `@tests/unit/environments/test_nemo_gym.py` around lines 46 - 47, The mutable
module-global run_examples_called should follow the project's naming and
test-safety conventions: rename it to G_RUN_EXAMPLES_CALLED (upper snake case
with G_ prefix) and ensure it's reset before each test to avoid cross-test
leakage; update all references to run_examples_called in this file (e.g., where
it is incremented or asserted) to G_RUN_EXAMPLES_CALLED and add a pytest fixture
or test setup that sets G_RUN_EXAMPLES_CALLED = 0 before each test runs.

Comment on lines 206 to 208
@pytest.fixture(scope="function")
def nemo_gym_with_patched_run_examples(nemo_gym_vllm_generation):
from nemo_gym.rollout_collection import RolloutCollectionHelper
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add Google‑style docstring to the new fixture.

The new fixture lacks a docstring; please add a Google‑style docstring for clarity and consistency.

✍️ Suggested fix
 def nemo_gym_with_patched_run_examples(nemo_gym_vllm_generation):
+    """Fixture that patches RolloutCollectionHelper.run_examples to inject NaN logprobs."""

As per coding guidelines, "Use Google style docstrings for classes and functions, which can be parsed by Sphinx."

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@pytest.fixture(scope="function")
def nemo_gym_with_patched_run_examples(nemo_gym_vllm_generation):
from nemo_gym.rollout_collection import RolloutCollectionHelper
`@pytest.fixture`(scope="function")
def nemo_gym_with_patched_run_examples(nemo_gym_vllm_generation):
"""Fixture that patches RolloutCollectionHelper.run_examples to inject NaN logprobs."""
from nemo_gym.rollout_collection import RolloutCollectionHelper
🤖 Prompt for AI Agents
In `@tests/unit/environments/test_nemo_gym.py` around lines 206 - 208, Add a
Google-style docstring to the pytest fixture nemo_gym_with_patched_run_examples
describing its purpose, parameters (if any) and return value; place it
immediately below the def nemo_gym_with_patched_run_examples(...) line and
follow Google style sections (Args:, Returns:) and mention that it yields a
nemo_gym instance with RolloutCollectionHelper.run_examples patched for tests so
readers and Sphinx can parse it.

Comment on lines 222 to 239
def new_run_examples(self, examples: List[Dict], head_server_config: Optional[BaseServerConfig] = None) -> Iterator[Future]:
print("[guyueh debug] calling new_run_examples: ")
global run_examples_called
run_examples_called += 1

for task in orig_run_examples(self, examples, head_server_config):
row, result = task
# insert a NaN in the generation_log_probs
has_generation_log_probs = False
for i in range(len(result["response"]["output"])):
if "generation_log_probs" in result["response"]["output"][i]:
result["response"]["output"][i]["generation_log_probs"][-1] = float('nan')
has_generation_log_probs = True
break

if not has_generation_log_probs:
raise ValueError("No generation_log_probs found in the result")
yield row, result
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's find the test file and understand its context
find . -name "test_nemo_gym.py" -type f

Repository: NVIDIA-NeMo/RL

Length of output: 100


🏁 Script executed:

# Let's look at the structure of the test file around lines 222-239
head -n 300 tests/unit/environments/test_nemo_gym.py | tail -n 100

Repository: NVIDIA-NeMo/RL

Length of output: 3698


🏁 Script executed:

# Search for the definition of orig_run_examples and how it's used
rg "orig_run_examples" -B 3 -A 3

Repository: NVIDIA-NeMo/RL

Length of output: 1854


🏁 Script executed:

# Search for RolloutCollectionHelper to understand run_examples behavior
rg "RolloutCollectionHelper" -B 2 -A 5

Repository: NVIDIA-NeMo/RL

Length of output: 3429


🏁 Script executed:

# Look for run_rollouts to understand how it awaits tasks
rg "run_rollouts" -B 3 -A 10

Repository: NVIDIA-NeMo/RL

Length of output: 6332


🏁 Script executed:

# Get more context around run_rollouts to see how the iterator is consumed
sed -n '105,200p' nemo_rl/environments/nemo_gym.py

Repository: NVIDIA-NeMo/RL

Length of output: 4423


🏁 Script executed:

# Look for await or asyncio patterns in nemo_gym.py
rg "await|asyncio" nemo_rl/environments/nemo_gym.py -B 2 -A 2

Repository: NVIDIA-NeMo/RL

Length of output: 354


Patched run_examples must yield awaitables, not tuples.

The run_rollouts method awaits each task returned from run_examples (line: nemo_gym_row, nemo_gym_result = await task). The patched method unpacks the awaitable from orig_run_examples and yields tuples instead, which will fail at runtime when await task is called on a tuple.

Wrap the awaitable and inject NaN after awaiting:

Suggested fix
-        def new_run_examples(self, examples: List[Dict], head_server_config: Optional[BaseServerConfig] = None) -> Iterator[Future]:
+        def new_run_examples(self, examples: List[Dict], head_server_config: Optional[BaseServerConfig] = None) -> Iterator[Future]:
             print("[guyueh debug] calling new_run_examples: ")
             global run_examples_called
             run_examples_called += 1
-
-            for task in orig_run_examples(self, examples, head_server_config):
-                row, result = task
-                # insert a NaN in the generation_log_probs
-                has_generation_log_probs = False
-                for i in range(len(result["response"]["output"])):
-                    if "generation_log_probs" in result["response"]["output"][i]:
-                        result["response"]["output"][i]["generation_log_probs"][-1] = float('nan')
-                        has_generation_log_probs = True
-                        break
-
-                if not has_generation_log_probs:
-                    raise ValueError("No generation_log_probs found in the result")
-                yield row, result
+            async def _wrap(task: Future):
+                row, result = await task
+                # insert a NaN in the generation_log_probs
+                has_generation_log_probs = False
+                for i in range(len(result["response"]["output"])):
+                    if "generation_log_probs" in result["response"]["output"][i]:
+                        result["response"]["output"][i]["generation_log_probs"][-1] = float("nan")
+                        has_generation_log_probs = True
+                        break
+
+                if not has_generation_log_probs:
+                    raise ValueError("No generation_log_probs found in the result")
+                return row, result
+
+            for task in orig_run_examples(self, examples, head_server_config):
+                yield _wrap(task)
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 238-238: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@tests/unit/environments/test_nemo_gym.py` around lines 222 - 239, The patched
new_run_examples is unpacking awaitables returned by orig_run_examples and
yielding tuples, causing await task to fail in run_rollouts; instead, for each
awaitable "task" from orig_run_examples(self, examples, head_server_config)
create and yield a new async wrapper coroutine (or future) that awaits the
original task, injects the NaN into result["response"]["output"] (preserve the
has_generation_log_probs check and raise ValueError if none), and then returns
the (row, result) pair—i.e., leave orig_run_examples and run_rollouts semantics
intact by yielding an awaitable that performs the mutation after awaiting the
original "task".

Comment on lines 247 to 295
context = patch_run_examples()
context.__enter__()

yaml_str = r"""example_multi_step_resources_server:
resources_servers:
example_multi_step:
entrypoint: app.py
domain: instruction_following
example_multi_step_simple_agent:
responses_api_agents:
simple_agent:
entrypoint: app.py
resources_server:
type: resources_servers
name: example_multi_step_resources_server
model_server:
type: responses_api_models
name: openai_model
openai_model:
responses_api_models:
vllm_model:
entrypoint: app.py
base_url: ${policy_base_url}
api_key: ${policy_api_key}
model: ${policy_model_name}
return_token_id_information: true
uses_reasoning_parser: true
"""

config = NemoGymConfig(
model_name=nemo_gym_vllm_generation.cfg["model_name"],
base_urls=nemo_gym_vllm_generation.dp_openai_server_base_urls,
initial_global_config_dict=safe_load(yaml_str),
rollout_max_retries_to_avoid_lp_nan=NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN,
)
env = NemoGym.options(
runtime_env={
"py_executable": get_actor_python_env(
"nemo_rl.environments.nemo_gym.NemoGym"
),
}
).remote(config)
ray.get(env.health_check.remote())
yield env
env.shutdown.remote()
ray.kill(env)
time.sleep(0.1)
context.__exit__(None, None, None)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Guarantee the run_examples patch is reverted even if fixture setup fails.

If actor creation fails before teardown, context.__exit__ is never called and the patch can leak into other tests. Wrap the fixture body in try/finally to ensure cleanup.

🧹 Suggested fix
-    context = patch_run_examples()
-    context.__enter__()
+    context = patch_run_examples()
+    context.__enter__()
+    env = None
     
     yaml_str = r"""example_multi_step_resources_server:
@@
-    env = NemoGym.options(
-        runtime_env={
-            "py_executable": get_actor_python_env(
-                "nemo_rl.environments.nemo_gym.NemoGym"
-                ),
-            }
-        ).remote(config)
-    ray.get(env.health_check.remote())
-    yield env
-    env.shutdown.remote()
-    ray.kill(env)
-    time.sleep(0.1)
-    context.__exit__(None, None, None)
+    try:
+        env = NemoGym.options(
+            runtime_env={
+                "py_executable": get_actor_python_env(
+                    "nemo_rl.environments.nemo_gym.NemoGym"
+                ),
+            }
+        ).remote(config)
+        ray.get(env.health_check.remote())
+        yield env
+    finally:
+        if env is not None:
+            env.shutdown.remote()
+            ray.kill(env)
+            time.sleep(0.1)
+        context.__exit__(None, None, None)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
context = patch_run_examples()
context.__enter__()
yaml_str = r"""example_multi_step_resources_server:
resources_servers:
example_multi_step:
entrypoint: app.py
domain: instruction_following
example_multi_step_simple_agent:
responses_api_agents:
simple_agent:
entrypoint: app.py
resources_server:
type: resources_servers
name: example_multi_step_resources_server
model_server:
type: responses_api_models
name: openai_model
openai_model:
responses_api_models:
vllm_model:
entrypoint: app.py
base_url: ${policy_base_url}
api_key: ${policy_api_key}
model: ${policy_model_name}
return_token_id_information: true
uses_reasoning_parser: true
"""
config = NemoGymConfig(
model_name=nemo_gym_vllm_generation.cfg["model_name"],
base_urls=nemo_gym_vllm_generation.dp_openai_server_base_urls,
initial_global_config_dict=safe_load(yaml_str),
rollout_max_retries_to_avoid_lp_nan=NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN,
)
env = NemoGym.options(
runtime_env={
"py_executable": get_actor_python_env(
"nemo_rl.environments.nemo_gym.NemoGym"
),
}
).remote(config)
ray.get(env.health_check.remote())
yield env
env.shutdown.remote()
ray.kill(env)
time.sleep(0.1)
context.__exit__(None, None, None)
context = patch_run_examples()
context.__enter__()
env = None
yaml_str = r"""example_multi_step_resources_server:
resources_servers:
example_multi_step:
entrypoint: app.py
domain: instruction_following
example_multi_step_simple_agent:
responses_api_agents:
simple_agent:
entrypoint: app.py
resources_server:
type: resources_servers
name: example_multi_step_resources_server
model_server:
type: responses_api_models
name: openai_model
openai_model:
responses_api_models:
vllm_model:
entrypoint: app.py
base_url: ${policy_base_url}
api_key: ${policy_api_key}
model: ${policy_model_name}
return_token_id_information: true
uses_reasoning_parser: true
"""
config = NemoGymConfig(
model_name=nemo_gym_vllm_generation.cfg["model_name"],
base_urls=nemo_gym_vllm_generation.dp_openai_server_base_urls,
initial_global_config_dict=safe_load(yaml_str),
rollout_max_retries_to_avoid_lp_nan=NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN,
)
try:
env = NemoGym.options(
runtime_env={
"py_executable": get_actor_python_env(
"nemo_rl.environments.nemo_gym.NemoGym"
),
}
).remote(config)
ray.get(env.health_check.remote())
yield env
finally:
if env is not None:
env.shutdown.remote()
ray.kill(env)
time.sleep(0.1)
context.__exit__(None, None, None)
🤖 Prompt for AI Agents
In `@tests/unit/environments/test_nemo_gym.py` around lines 247 - 295, The fixture
currently calls context.__enter__ and yields env but if actor creation or setup
fails the function exits before calling context.__exit__, leaking the patch;
wrap the setup and yield in a try/finally so context.__exit__ is always called:
call context.__enter__ first, then create config and env via
NemoGym.options(...).remote and perform ray.get(env.health_check.remote())
inside the try, yield env as before, and in the finally ensure you call
env.shutdown.remote() and ray.kill(env) only if env was created, then call
context.__exit__(None, None, None) to guarantee the patch_run_examples context
is reverted even on failures.

@guyueh1 guyueh1 self-assigned this Feb 5, 2026
@guyueh1 guyueh1 added the CI:L2 Run doctests, unit tests, functional tests, and convergence tests label Feb 9, 2026
Signed-off-by: Guyue Huang <[email protected]>
@guyueh1 guyueh1 added CI:L2 Run doctests, unit tests, functional tests, and convergence tests and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Feb 9, 2026
Signed-off-by: Guyue Huang <[email protected]>
@guyueh1 guyueh1 added CI:L2 Run doctests, unit tests, functional tests, and convergence tests and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Feb 9, 2026
@guyueh1 guyueh1 added CI:L2 Run doctests, unit tests, functional tests, and convergence tests and removed CI:L2 Run doctests, unit tests, functional tests, and convergence tests labels Feb 10, 2026
@guyueh1
Copy link
Contributor Author

guyueh1 commented Feb 10, 2026

@terrykong the pipeline has passed at commit a3cd107 can you review?

Copy link
Contributor

@terrykong terrykong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generally lgtm. small comments

model_name=policy_generation.cfg["model_name"],
base_urls=policy_generation.dp_openai_server_base_urls,
initial_global_config_dict=config["env"]["nemo_gym"],
rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg.get(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

couple of things:

  1. can we move the default to the yaml and avoid defaulting here
  2. can we add it to the generation config typeddict with docstring
  3. can we add asserts (==1) that say this has no effect in other places like (when nemo-gym isn't used) just so users are not under the impression that other paths are less stable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about moving rollout_max_retries_to_avoid_lp_nan to gym env config instead of putting it at generation config? since it's only used for gym and it's implemented in gym env.

Comment on lines +116 to +117
max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
while trial < max_retries:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
while trial < max_retries:
max_attempts, trial = self.cfg["rollout_max_attempts_to_avoid_lp_nan"], 0
while trial < max_attempts:

nit: how about naming it "attempts" instead of "retries" since it you aren't actually retrying 1 time, it's more like 0 times by default

nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
nemo_gym_result, tokenizer
)
max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also maybe

assert self.cfg["rollout_max_attempts_to_avoid_lp_nan"] >= 1, .....

just to give a nice user error instead of skipping this while loop

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also maybe good to put this assert at init part instead of here.

Copy link
Contributor

@yuki-97 yuki-97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm and left minor comments.

model_name=policy_generation.cfg["model_name"],
base_urls=policy_generation.dp_openai_server_base_urls,
initial_global_config_dict=config["env"]["nemo_gym"],
rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg.get(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about moving rollout_max_retries_to_avoid_lp_nan to gym env config instead of putting it at generation config? since it's only used for gym and it's implemented in gym env.

nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
nemo_gym_result, tokenizer
)
max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also maybe good to put this assert at init part instead of here.

Co-authored-by: Terry Kong <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L2 Run doctests, unit tests, functional tests, and convergence tests super-v3

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants