Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
batched_message_log_to_flat_message,
get_keys_from_message_log,
)
from nemo_rl.data.utils import extract_necessary_env_names
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
Expand Down Expand Up @@ -341,12 +342,12 @@ def setup(
# ==========================
print("\n▶ Setting up compute cluster...", flush=True)
colocated_inference = generation_config["colocated"]["enabled"]
reward_model_enabled = (
"env_name" in data_config and data_config["env_name"] == "reward_model"
)

env_name_list = extract_necessary_env_names(data_config)
rm_env_enabled = "reward_model" in env_name_list

total_nodes = cluster_config["num_nodes"]
if reward_model_enabled:
if rm_env_enabled:
rm_resource = env_configs["reward_model"]["resources"]
rm_nodes = rm_resource["num_nodes"]
rm_gpus_per_node = rm_resource["gpus_per_node"]
Expand Down Expand Up @@ -423,15 +424,15 @@ def setup(
inference_nodes = 1
# If total_nodes == 1, reward model is also on the same node; otherwise it's on a different node
reward_gpus_to_subtract = (
rm_gpus_per_node if total_nodes == 1 and reward_model_enabled else 0
rm_gpus_per_node if total_nodes == 1 and rm_env_enabled else 0
)
train_gpus_per_node -= inference_gpus_per_node + reward_gpus_to_subtract
assert train_gpus_per_node > 0, (
"No enough GPUs for training, "
f"train_gpus_per_node:{train_gpus_per_node} = cluster_config['gpus_per_node']:{cluster_config['gpus_per_node']} - inference_gpus_per_node:{inference_gpus_per_node}"
+ (
f" - rm_gpus_per_node:{rm_gpus_per_node}"
if total_nodes == 1 and reward_model_enabled
if total_nodes == 1 and rm_env_enabled
else ""
)
)
Expand Down
3 changes: 1 addition & 2 deletions tests/functional/L1_Functional_Tests_GPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh
time uv run --no-sync bash ./tests/functional/grpo_multiple_datasets.sh
time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh
# Re-enable once it is fixed.
# time uv run --no-sync bash ./tests/functional/grpo_rm_env.sh
time uv run --no-sync bash ./tests/functional/grpo_rm_env.sh
# Re-enable once SGLang build is fixed.
# time uv run --no-sync bash ./tests/functional/grpo_sglang.sh
time uv run --no-sync bash ./tests/functional/prorlv2.sh
Expand Down
Loading