Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ env:
should_use_nemo_gym: true
should_log_nemo_gym_responses: true # If you have low logging storage, set this to false
nemo_gym: # This is passed into NeMo-Gym as the initial_global_config_dict
rollout_max_attempts_to_avoid_lp_nan: 1
is_trajectory_collection: false # Set this to true to enable trajectory collection (no training). You may also want to increase `policy.generation.vllm_cfg.gpu_memory_utilization`
config_paths:
- responses_api_models/vllm_model/configs/vllm_model_for_training.yaml # Required! And it must be *for_training
Expand Down
64 changes: 47 additions & 17 deletions nemo_rl/environments/nemo_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ def __init__(self, cfg: NemoGymConfig):
"port": self.head_server_port,
}

self.rollout_max_attempts_to_avoid_lp_nan = initial_global_config_dict.pop(
"rollout_max_attempts_to_avoid_lp_nan", 1
)

assert self.rollout_max_attempts_to_avoid_lp_nan >= 1, (
"`rollout_max_attempts_to_avoid_lp_nan` must be at least 1"
)

self.rh = RunHelper()
self.rh.start(
global_config_dict_parser_config=GlobalConfigDictParserConfig(
Expand Down Expand Up @@ -110,25 +118,47 @@ async def run_rollouts(
) -> list[dict]:
timer = Timer()

nemo_gym_num_rows = len(nemo_gym_examples)
nemo_gym_result_iterator = self.rch.run_examples(
examples=nemo_gym_examples, head_server_config=self.head_server_config
)

timer.start("_run_rollouts_total")
nemo_rl_rowidxs = []
nemo_rl_results = []
for task in nemo_gym_result_iterator:
with timer.time(label=f"{timer_prefix}/await_results"):
nemo_gym_row, nemo_gym_result = await task

with timer.time(label=f"{timer_prefix}/postprocess_results"):
nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
nemo_gym_result, tokenizer
)
max_attempts, trial = self.rollout_max_attempts_to_avoid_lp_nan, 0
while trial < max_attempts:
nemo_gym_num_rows = len(nemo_gym_examples)
nemo_gym_result_iterator = self.rch.run_examples(
examples=nemo_gym_examples, head_server_config=self.head_server_config
)

nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
nemo_rl_results.append(nemo_rl_result)
nemo_rl_rowidxs = []
nemo_rl_results = []
for task in nemo_gym_result_iterator:
with timer.time(label=f"{timer_prefix}/await_results"):
nemo_gym_row, nemo_gym_result = await task

with timer.time(label=f"{timer_prefix}/postprocess_results"):
nemo_rl_result = self._postprocess_nemo_gym_to_nemo_rl_result(
nemo_gym_result, tokenizer
)

nemo_rl_rowidxs.append(nemo_gym_row["_rowidx"])
nemo_rl_results.append(nemo_rl_result)

# determine if generation_logprobs contain NaN; if not, break;
logprob_contains_nan = False
for nemo_rl_result in nemo_rl_results:
for message in nemo_rl_result["message_log"]:
if (
"generation_logprobs" in message
and message["generation_logprobs"] is not None
):
if torch.isnan(message["generation_logprobs"]).any():
logprob_contains_nan = True
break
if logprob_contains_nan:
trial += 1
print(
f"Generation logprobs contain NaN; retrying... (trial {trial}/{max_attempts})"
)
continue
else:
break

nemo_rl_sort_results = [None] * nemo_gym_num_rows
for rowidx, result in zip(nemo_rl_rowidxs, nemo_rl_results):
Expand Down
1 change: 1 addition & 0 deletions tests/unit/environments/test_nemo_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def nemo_gym(nemo_gym_vllm_generation):
model: ${policy_model_name}
return_token_id_information: true
uses_reasoning_parser: true
rollout_max_attempts_to_avoid_lp_nan: 1
"""

config = NemoGymConfig(
Expand Down
Loading