feat: retry rollout if generation_logprobs contains NaN #1885
base: main
Conversation
Signed-off-by: Guyue Huang <[email protected]>
📝 Walkthrough: The changes introduce retry logic to NeMo-Gym rollout collection to handle NaN values in generation log probabilities. A new configuration field enables retrying rollouts up to a maximum count, with NaN detection applied to the collected results.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/nemo_gym/run_grpo_nemo_gym.py (1)

Lines 252-257: ⚠️ Potential issue | 🟠 Major: Avoid code-level default for `rollout_max_retries_to_avoid_lp_nan`.

Using `cfg.get(..., 1)` introduces a hidden default in code; please read the key directly and define the default in YAML.

🔧 Suggested fix

```diff
- rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg.get("rollout_max_retries_to_avoid_lp_nan", 1),
+ rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg["rollout_max_retries_to_avoid_lp_nan"],
```

As per coding guidelines, "YAML is the single source of truth for configuration defaults; do not set non-None defaults in code for configuration values" and "Access required config values directly (e.g., `policy_cfg['precision']`) and assume they are present; do not introduce hidden defaults in code."

nemo_rl/environments/nemo_gym.py (1)

Lines 27-31: ⚠️ Potential issue | 🟠 Major: Remove the TypedDict default and mark the field as required.

`rollout_max_retries_to_avoid_lp_nan: int = 1` sets a non-None default in code, which violates the guideline that YAML is the single source of truth for configuration defaults. In TypedDict, this pattern is also inconsistent with the codebase convention of using `NotRequired[int]` for optional fields. Since the field is always explicitly provided at instantiation sites (never relying on the class default), it should be declared as required without a default value.

🔧 Suggested fix

```diff
 class NemoGymConfig(TypedDict):
     model_name: str
     base_urls: List[str]
     initial_global_config_dict: Dict[str, Any]
-    rollout_max_retries_to_avoid_lp_nan: int = 1
+    rollout_max_retries_to_avoid_lp_nan: int
```
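Moving the default into YAML, as the first comment asks, could look like the following sketch (hedged: the config file path and nesting are assumptions, not taken from the repo):

```yaml
# e.g. in the GRPO example config (location and nesting assumed)
policy:
  generation:
    rollout_max_retries_to_avoid_lp_nan: 1  # total attempts; YAML is the single source of truth
```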
🤖 Fix all issues with AI agents
In `@nemo_rl/environments/nemo_gym.py`:
- Around line 115-152: Validate and clearly define the semantics of max_retries
before entering the loop: ensure cfg["rollout_max_retries_to_avoid_lp_nan"] is
an int >= 1 (or raise a ValueError) so nemo_gym_num_rows is always defined; keep
current semantics as "max attempts" by replacing the while trial < max_retries
loop with a for attempt in range(max_retries): or explicitly document that
max_retries is the total number of attempts, and remove the off‑by‑one
ambiguity; reference the variables max_retries, trial, the while trial <
max_retries loop, and nemo_gym_num_rows when making the check and adjustment.
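The loop restructuring described in this prompt can be sketched in plain Python (hedged: `run_with_bounded_attempts` and `run_once` are illustrative stand-ins for the real rollout call, and the NaN check is reduced to scalar logprobs):

```python
import math


def run_with_bounded_attempts(cfg: dict, run_once) -> list:
    """Run a rollout up to a bounded number of attempts.

    rollout_max_retries_to_avoid_lp_nan is treated as the total number of
    attempts; a value < 1 is rejected up front so the loop body (and the
    variables it defines) always executes at least once.
    """
    max_attempts = cfg["rollout_max_retries_to_avoid_lp_nan"]
    if not isinstance(max_attempts, int) or max_attempts < 1:
        raise ValueError("rollout_max_retries_to_avoid_lp_nan must be an int >= 1")
    results: list = []
    for attempt in range(max_attempts):
        results = run_once()
        if not any(math.isnan(x) for x in results):
            break  # clean logprobs; no retry needed
    return results
```

The `for attempt in range(max_attempts)` shape removes the off-by-one ambiguity the prompt mentions: the setting is unambiguously a count of attempts, not retries.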
In `@tests/unit/environments/test_nemo_gym.py`:
- Around line 206-208: Add a Google-style docstring to the pytest fixture
nemo_gym_with_patched_run_examples describing its purpose, parameters (if any)
and return value; place it immediately below the def
nemo_gym_with_patched_run_examples(...) line and follow Google style sections
(Args:, Returns:) and mention that it yields a nemo_gym instance with
RolloutCollectionHelper.run_examples patched for tests so readers and Sphinx can
parse it.
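A Google-style docstring for that fixture might look like the following sketch (hedged: the decorator and real body are omitted so the docstring shape stands alone; only the names come from the test file):

```python
def nemo_gym_with_patched_run_examples(nemo_gym_vllm_generation):
    """Yield a NemoGym actor with RolloutCollectionHelper.run_examples patched.

    The patch injects a NaN into generation_log_probs so that the rollout
    retry path can be exercised deterministically in tests.

    Args:
        nemo_gym_vllm_generation: Fixture providing the vLLM-backed
            generation backend the environment talks to.

    Yields:
        The NemoGym actor handle with the patch active.
    """
    yield nemo_gym_vllm_generation  # placeholder body for the sketch
```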
- Around line 247-295: The fixture currently calls context.__enter__ and yields
env but if actor creation or setup fails the function exits before calling
context.__exit__, leaking the patch; wrap the setup and yield in a try/finally
so context.__exit__ is always called: call context.__enter__ first, then create
config and env via NemoGym.options(...).remote and perform
ray.get(env.health_check.remote()) inside the try, yield env as before, and in
the finally ensure you call env.shutdown.remote() and ray.kill(env) only if env
was created, then call context.__exit__(None, None, None) to guarantee the
patch_run_examples context is reverted even on failures.
- Around line 46-47: The mutable module-global run_examples_called should follow
the project's naming and test-safety conventions: rename it to
G_RUN_EXAMPLES_CALLED (upper snake case with G_ prefix) and ensure it's reset
before each test to avoid cross-test leakage; update all references to
run_examples_called in this file (e.g., where it is incremented or asserted) to
G_RUN_EXAMPLES_CALLED and add a pytest fixture or test setup that sets
G_RUN_EXAMPLES_CALLED = 0 before each test runs.
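The rename-and-reset convention can be sketched as follows (hedged: the helper names are illustrative; in the real test file the reset would live in an autouse pytest fixture):

```python
# Module-level mutable test counter, named per the G_ convention.
G_RUN_EXAMPLES_CALLED = 0


def reset_run_examples_counter() -> None:
    """Reset the counter; call before each test to avoid cross-test leakage."""
    global G_RUN_EXAMPLES_CALLED
    G_RUN_EXAMPLES_CALLED = 0


def record_run_examples_call() -> int:
    """Increment and return the counter (what the patched run_examples does)."""
    global G_RUN_EXAMPLES_CALLED
    G_RUN_EXAMPLES_CALLED += 1
    return G_RUN_EXAMPLES_CALLED
```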
- Around line 222-239: The patched new_run_examples is unpacking awaitables
returned by orig_run_examples and yielding tuples, causing await task to fail in
run_rollouts; instead, for each awaitable "task" from orig_run_examples(self,
examples, head_server_config) create and yield a new async wrapper coroutine (or
future) that awaits the original task, injects the NaN into
result["response"]["output"] (preserve the has_generation_log_probs check and
raise ValueError if none), and then returns the (row, result) pair—i.e., leave
orig_run_examples and run_rollouts semantics intact by yielding an awaitable
that performs the mutation after awaiting the original "task".
nemo_rl/environments/nemo_gym.py
Outdated
```diff
             timer.start("_run_rollouts_total")
-            nemo_rl_rowidxs = []
-            nemo_rl_results = []
-            for task in nemo_gym_result_iterator:
-                with timer.time(label=f"{timer_prefix}/await_results"):
-                    nemo_gym_row, nemo_gym_result = await task
-                with timer.time(label=f"{timer_prefix}/postprocess_results"):
-                    nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
-                        nemo_gym_result, tokenizer
-                    )
-                nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
-                nemo_rl_results.append(nemo_rl_result)
+            max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
+            while trial < max_retries:
+                nemo_gym_num_rows = len(nemo_gym_examples)
+                nemo_gym_result_iterator = self.rch.run_examples(
+                    examples=nemo_gym_examples, head_server_config=self.head_server_config
+                )
+                nemo_rl_rowidxs = []
+                nemo_rl_results = []
+                for task in nemo_gym_result_iterator:
+                    with timer.time(label=f"{timer_prefix}/await_results"):
+                        nemo_gym_row, nemo_gym_result = await task
+                    with timer.time(label=f"{timer_prefix}/postprocess_results"):
+                        nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
+                            nemo_gym_result, tokenizer
+                        )
+                    nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
+                    nemo_rl_results.append(nemo_rl_result)
+
+                # determine if generation_logprobs contain NaN; if not, break
+                logprob_contains_nan = False
+                for nemo_rl_result in nemo_rl_results:
+                    for message in nemo_rl_result["message_log"]:
+                        if "generation_logprobs" in message and message["generation_logprobs"] is not None:
+                            if torch.isnan(message["generation_logprobs"]).any():
+                                logprob_contains_nan = True
+                                break
+                if logprob_contains_nan:
+                    trial += 1
+                    print(f"Generation logprobs contain NaN; retrying... (trial {trial}/{max_retries})")
+                    continue
+                else:
+                    break
```
Clarify retry semantics and guard against non‑positive values.
`while trial < max_retries` treats the setting as max attempts; if users interpret it as "retries," it's off-by-one. Also, `max_retries <= 0` leaves `nemo_gym_num_rows` undefined. Please validate and clearly define the semantics.
🔧 Suggested fix (validation; keep current semantics)

```diff
-            max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
-            while trial < max_retries:
+            max_retries = self.cfg["rollout_max_retries_to_avoid_lp_nan"]
+            if max_retries < 1:
+                raise ValueError("rollout_max_retries_to_avoid_lp_nan must be >= 1")
+            trial = 0
+            while trial < max_retries:
```
```python
NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN = 3
run_examples_called = 0
```
Use G_ prefix for mutable global counter and reset per test.
`run_examples_called` is a mutable module global; rename it to the required `G_`-prefixed upper snake case and reset it in the fixture to prevent cross-test leakage.
🔧 Suggested fix

```diff
 NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN = 3
-run_examples_called = 0
+G_RUN_EXAMPLES_CALLED = 0
@@
-    global run_examples_called
-    run_examples_called += 1
+    global G_RUN_EXAMPLES_CALLED
+    G_RUN_EXAMPLES_CALLED += 1
@@
-    context = patch_run_examples()
+    global G_RUN_EXAMPLES_CALLED
+    G_RUN_EXAMPLES_CALLED = 0
+    context = patch_run_examples()
@@
-    assert run_examples_called == NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN
+    assert G_RUN_EXAMPLES_CALLED == NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN
```

As per coding guidelines, "Use upper snake_case with G prefix for global variables, e.g., G_MY_GLOBAL."
```python
@pytest.fixture(scope="function")
def nemo_gym_with_patched_run_examples(nemo_gym_vllm_generation):
    from nemo_gym.rollout_collection import RolloutCollectionHelper
```
Add Google‑style docstring to the new fixture.
The new fixture lacks a docstring; please add a Google‑style docstring for clarity and consistency.
✍️ Suggested fix

```diff
 def nemo_gym_with_patched_run_examples(nemo_gym_vllm_generation):
+    """Fixture that patches RolloutCollectionHelper.run_examples to inject NaN logprobs."""
```

As per coding guidelines, "Use Google style docstrings for classes and functions, which can be parsed by Sphinx."
```python
def new_run_examples(self, examples: List[Dict], head_server_config: Optional[BaseServerConfig] = None) -> Iterator[Future]:
    print("[guyueh debug] calling new_run_examples: ")
    global run_examples_called
    run_examples_called += 1

    for task in orig_run_examples(self, examples, head_server_config):
        row, result = task
        # insert a NaN in the generation_log_probs
        has_generation_log_probs = False
        for i in range(len(result["response"]["output"])):
            if "generation_log_probs" in result["response"]["output"][i]:
                result["response"]["output"][i]["generation_log_probs"][-1] = float('nan')
                has_generation_log_probs = True
                break

        if not has_generation_log_probs:
            raise ValueError("No generation_log_probs found in the result")
        yield row, result
```
🧩 Analysis chain

🏁 Scripts executed (repository: NVIDIA-NeMo/RL):

```shell
# First, let's find the test file and understand its context
find . -name "test_nemo_gym.py" -type f

# Let's look at the structure of the test file around lines 222-239
head -n 300 tests/unit/environments/test_nemo_gym.py | tail -n 100

# Search for the definition of orig_run_examples and how it's used
rg "orig_run_examples" -B 3 -A 3

# Search for RolloutCollectionHelper to understand run_examples behavior
rg "RolloutCollectionHelper" -B 2 -A 5

# Look for run_rollouts to understand how it awaits tasks
rg "run_rollouts" -B 3 -A 10

# Get more context around run_rollouts to see how the iterator is consumed
sed -n '105,200p' nemo_rl/environments/nemo_gym.py

# Look for await or asyncio patterns in nemo_gym.py
rg "await|asyncio" nemo_rl/environments/nemo_gym.py -B 2 -A 2
```
Patched `run_examples` must yield awaitables, not tuples.
The `run_rollouts` method awaits each task returned from `run_examples` (`nemo_gym_row, nemo_gym_result = await task`). The patched method unpacks the items from `orig_run_examples` and yields tuples instead, which will fail at runtime when `await task` is called on a tuple.
Wrap the awaitable and inject NaN after awaiting:
Suggested fix

```diff
 def new_run_examples(self, examples: List[Dict], head_server_config: Optional[BaseServerConfig] = None) -> Iterator[Future]:
     print("[guyueh debug] calling new_run_examples: ")
     global run_examples_called
     run_examples_called += 1

-    for task in orig_run_examples(self, examples, head_server_config):
-        row, result = task
-        # insert a NaN in the generation_log_probs
-        has_generation_log_probs = False
-        for i in range(len(result["response"]["output"])):
-            if "generation_log_probs" in result["response"]["output"][i]:
-                result["response"]["output"][i]["generation_log_probs"][-1] = float('nan')
-                has_generation_log_probs = True
-                break
-
-        if not has_generation_log_probs:
-            raise ValueError("No generation_log_probs found in the result")
-        yield row, result
+    async def _wrap(task: Future):
+        row, result = await task
+        # insert a NaN in the generation_log_probs
+        has_generation_log_probs = False
+        for i in range(len(result["response"]["output"])):
+            if "generation_log_probs" in result["response"]["output"][i]:
+                result["response"]["output"][i]["generation_log_probs"][-1] = float("nan")
+                has_generation_log_probs = True
+                break
+
+        if not has_generation_log_probs:
+            raise ValueError("No generation_log_probs found in the result")
+        return row, result
+
+    for task in orig_run_examples(self, examples, head_server_config):
+        yield _wrap(task)
```

🧰 Tools
🪛 Ruff (0.14.14)
[warning] 238-238: Avoid specifying long messages outside the exception class
(TRY003)
```python
context = patch_run_examples()
context.__enter__()

yaml_str = r"""example_multi_step_resources_server:
  resources_servers:
    example_multi_step:
      entrypoint: app.py
      domain: instruction_following
example_multi_step_simple_agent:
  responses_api_agents:
    simple_agent:
      entrypoint: app.py
      resources_server:
        type: resources_servers
        name: example_multi_step_resources_server
      model_server:
        type: responses_api_models
        name: openai_model
openai_model:
  responses_api_models:
    vllm_model:
      entrypoint: app.py
      base_url: ${policy_base_url}
      api_key: ${policy_api_key}
      model: ${policy_model_name}
      return_token_id_information: true
      uses_reasoning_parser: true
"""

config = NemoGymConfig(
    model_name=nemo_gym_vllm_generation.cfg["model_name"],
    base_urls=nemo_gym_vllm_generation.dp_openai_server_base_urls,
    initial_global_config_dict=safe_load(yaml_str),
    rollout_max_retries_to_avoid_lp_nan=NEMO_GYM_ROLLOUT_MAX_RETRIES_TO_AVOID_LP_NAN,
)
env = NemoGym.options(
    runtime_env={
        "py_executable": get_actor_python_env(
            "nemo_rl.environments.nemo_gym.NemoGym"
        ),
    }
).remote(config)
ray.get(env.health_check.remote())
yield env
env.shutdown.remote()
ray.kill(env)
time.sleep(0.1)
context.__exit__(None, None, None)
```
Guarantee the run_examples patch is reverted even if fixture setup fails.
If actor creation fails before teardown, context.__exit__ is never called and the patch can leak into other tests. Wrap the fixture body in try/finally to ensure cleanup.
🧹 Suggested fix

```diff
     context = patch_run_examples()
     context.__enter__()
+    env = None

     yaml_str = r"""example_multi_step_resources_server:
@@
-    env = NemoGym.options(
-        runtime_env={
-            "py_executable": get_actor_python_env(
-                "nemo_rl.environments.nemo_gym.NemoGym"
-            ),
-        }
-    ).remote(config)
-    ray.get(env.health_check.remote())
-    yield env
-    env.shutdown.remote()
-    ray.kill(env)
-    time.sleep(0.1)
-    context.__exit__(None, None, None)
+    try:
+        env = NemoGym.options(
+            runtime_env={
+                "py_executable": get_actor_python_env(
+                    "nemo_rl.environments.nemo_gym.NemoGym"
+                ),
+            }
+        ).remote(config)
+        ray.get(env.health_check.remote())
+        yield env
+    finally:
+        if env is not None:
+            env.shutdown.remote()
+            ray.kill(env)
+        time.sleep(0.1)
+        context.__exit__(None, None, None)
```
Signed-off-by: Guyue Huang <[email protected]>
@terrykong the pipeline has passed at commit a3cd107, can you review?
terrykong left a comment:
generally lgtm. small comments
```python
model_name=policy_generation.cfg["model_name"],
base_urls=policy_generation.dp_openai_server_base_urls,
initial_global_config_dict=config["env"]["nemo_gym"],
rollout_max_retries_to_avoid_lp_nan=policy_generation.cfg.get(
```
couple of things:
- can we move the default to the yaml and avoid defaulting here
- can we add it to the generation config typeddict with docstring
- can we add asserts (`== 1`) in other places (e.g., when nemo-gym isn't used) confirming this has no effect there, just so users are not under the impression that other paths are less stable.
what about moving rollout_max_retries_to_avoid_lp_nan to gym env config instead of putting it at generation config? since it's only used for gym and it's implemented in gym env.
```python
max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
while trial < max_retries:
```
```diff
-max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
-while trial < max_retries:
+max_attempts, trial = self.cfg["rollout_max_attempts_to_avoid_lp_nan"], 0
+while trial < max_attempts:
```
nit: how about naming it "attempts" instead of "retries", since you aren't actually retrying 1 time by default; it's more like 0 retries.
```python
nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
    nemo_gym_result, tokenizer
)
max_retries, trial = self.cfg["rollout_max_retries_to_avoid_lp_nan"], 0
```
also maybe
`assert self.cfg["rollout_max_attempts_to_avoid_lp_nan"] >= 1, ...`
just to give a nice user error instead of skipping this while loop
also maybe good to put this assert at init part instead of here.
yuki-97 left a comment:
lgtm and left minor comments.
Co-authored-by: Terry Kong <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>
What does this PR do?
In nemo_gym environment, when generation_logprobs contains NaN, retry rollout.
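The detection at the core of this change can be sketched in isolation (hedged: the real implementation checks torch tensors with `torch.isnan(...).any()`; this sketch uses plain lists and `math.isnan`):

```python
import math


def logprobs_contain_nan(results) -> bool:
    """Return True if any message in any result carries a NaN generation logprob."""
    for result in results:
        for message in result["message_log"]:
            logprobs = message.get("generation_logprobs")
            if logprobs is not None and any(math.isnan(x) for x in logprobs):
                return True
    return False
```

When this returns True, the rollout is re-run, up to the configured maximum number of attempts.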
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Tests