diff --git a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml
index dcf9b7dde5..ff5d13424c 100644
--- a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml
+++ b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml
@@ -262,6 +262,7 @@ env:
   should_use_nemo_gym: true
   should_log_nemo_gym_responses: true # If you have low logging storage, set this to false
   nemo_gym: # This is passed into NeMo-Gym as the initial_global_config_dict
+    rollout_max_attempts_to_avoid_lp_nan: 1
     is_trajectory_collection: false # Set this to true to enable trajectory collection (no training). You may also want to increase `policy.generation.vllm_cfg.gpu_memory_utilization`
     config_paths:
      - responses_api_models/vllm_model/configs/vllm_model_for_training.yaml # Required! And it must be *for_training
diff --git a/nemo_rl/environments/nemo_gym.py b/nemo_rl/environments/nemo_gym.py
index 6694d4f4d1..d571f1a93f 100644
--- a/nemo_rl/environments/nemo_gym.py
+++ b/nemo_rl/environments/nemo_gym.py
@@ -82,6 +82,14 @@ def __init__(self, cfg: NemoGymConfig):
             "port": self.head_server_port,
         }
 
+        self.rollout_max_attempts_to_avoid_lp_nan = initial_global_config_dict.pop(
+            "rollout_max_attempts_to_avoid_lp_nan", 1
+        )
+
+        assert self.rollout_max_attempts_to_avoid_lp_nan >= 1, (
+            "`rollout_max_attempts_to_avoid_lp_nan` must be at least 1"
+        )
+
         self.rh = RunHelper()
         self.rh.start(
             global_config_dict_parser_config=GlobalConfigDictParserConfig(
@@ -110,25 +118,47 @@ async def run_rollouts(
     ) -> list[dict]:
         timer = Timer()
 
-        nemo_gym_num_rows = len(nemo_gym_examples)
-        nemo_gym_result_iterator = self.rch.run_examples(
-            examples=nemo_gym_examples, head_server_config=self.head_server_config
-        )
-
         timer.start("_run_rollouts_total")
-        nemo_rl_rowidxs = []
-        nemo_rl_results = []
-        for task in nemo_gym_result_iterator:
-            with timer.time(label=f"{timer_prefix}/await_results"):
-                nemo_gym_row, nemo_gym_result = await task
-
-            with timer.time(label=f"{timer_prefix}/postprocess_results"):
-                nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
-                    nemo_gym_result, tokenizer
-                )
+        max_attempts, trial = self.rollout_max_attempts_to_avoid_lp_nan, 0
+        while trial < max_attempts:
+            nemo_gym_num_rows = len(nemo_gym_examples)
+            nemo_gym_result_iterator = self.rch.run_examples(
+                examples=nemo_gym_examples, head_server_config=self.head_server_config
+            )
 
-            nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
-            nemo_rl_results.append(nemo_rl_result)
+            nemo_rl_rowidxs = []
+            nemo_rl_results = []
+            for task in nemo_gym_result_iterator:
+                with timer.time(label=f"{timer_prefix}/await_results"):
+                    nemo_gym_row, nemo_gym_result = await task
+
+                with timer.time(label=f"{timer_prefix}/postprocess_results"):
+                    nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
+                        nemo_gym_result, tokenizer
+                    )
+
+                nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
+                nemo_rl_results.append(nemo_rl_result)
+
+            # determine if generation_logprobs contain NaN; if not, break;
+            logprob_contains_nan = False
+            for nemo_rl_result in nemo_rl_results:
+                for message in nemo_rl_result["message_log"]:
+                    if (
+                        "generation_logprobs" in message
+                        and message["generation_logprobs"] is not None
+                    ):
+                        if torch.isnan(message["generation_logprobs"]).any():
+                            logprob_contains_nan = True
+                            break
+            if logprob_contains_nan:
+                trial += 1
+                print(
+                    f"Generation logprobs contain NaN; retrying... (trial {trial}/{max_attempts})"
+                )
+                continue
+            else:
+                break
 
         nemo_rl_sort_results = [None] * nemo_gym_num_rows
         for rowidx, result in zip(nemo_rl_rowidxs, nemo_rl_results):
diff --git a/tests/unit/environments/test_nemo_gym.py b/tests/unit/environments/test_nemo_gym.py
index 9812e23d17..05ce0936ed 100644
--- a/tests/unit/environments/test_nemo_gym.py
+++ b/tests/unit/environments/test_nemo_gym.py
@@ -106,6 +106,7 @@ def nemo_gym(nemo_gym_vllm_generation):
   model: ${policy_model_name}
   return_token_id_information: true
   uses_reasoning_parser: true
+rollout_max_attempts_to_avoid_lp_nan: 1
"""

    config = NemoGymConfig(