Signed-off-by: Sadegh Mahdavi <smahdavi@nvidia.com>
```python
for server_idx in range(n_servers):
    server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
    server_executor = get_executor(
        cluster_config=cluster_config,
        container=server_container,
        num_nodes=server_config["num_nodes"],
        tasks_per_node=num_server_tasks,
        gpus_per_node=server_config["num_gpus"],
        partition=partition,
        dependencies=dependencies,
        job_name=task_name,
        log_dir=log_dir,
        log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
        extra_package_dirs=extra_package_dirs,
        sbatch_kwargs=sbatch_kwargs,
        heterogeneous=heterogeneous,
        het_group=het_group,
        total_het_groups=total_het_groups,
        overlap=(not client_num_gpus),  # Only overlap when the main task does not have gpus
        with_ray=False,
        ray_template=ray_template,
    )
    cmd_to_add = server_cmd
    if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
        cmd_to_add = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
    commands.append(cmd_to_add)
    executors.append(server_executor)
    het_group_indices.append(het_group)
    het_group += 1
    LOG.info("Server %d command: %s", server_idx, server_cmd)
```
All servers launched in the loop use the same `server_port` from `server_config`, causing port conflicts when `n_servers > 1`; each server instance needs a unique port.
Suggested change:

```diff
 for server_idx in range(n_servers):
-    server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
+    # Get a unique port for each server if launching multiple
+    current_server_config = server_config.copy()
+    if n_servers > 1:
+        current_server_config["server_port"] = get_free_port(strategy="random")
+    server_cmd, num_server_tasks = get_server_command(**current_server_config, cluster_config=cluster_config)
     server_executor = get_executor(
         cluster_config=cluster_config,
         container=server_container,
-        num_nodes=server_config["num_nodes"],
+        num_nodes=current_server_config["num_nodes"],
         tasks_per_node=num_server_tasks,
-        gpus_per_node=server_config["num_gpus"],
+        gpus_per_node=current_server_config["num_gpus"],
         partition=partition,
         dependencies=dependencies,
         job_name=task_name,
         log_dir=log_dir,
         log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
         extra_package_dirs=extra_package_dirs,
         sbatch_kwargs=sbatch_kwargs,
         heterogeneous=heterogeneous,
         het_group=het_group,
         total_het_groups=total_het_groups,
         overlap=(not client_num_gpus),  # Only overlap when the main task does not have gpus
         with_ray=False,
         ray_template=ray_template,
     )
     cmd_to_add = server_cmd
     if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
         cmd_to_add = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
     commands.append(cmd_to_add)
     executors.append(server_executor)
     het_group_indices.append(het_group)
     het_group += 1
     LOG.info("Server %d command: %s", server_idx, server_cmd)
```
**📝 Walkthrough**

Introduces LLM-as-a-judge server support to the GRPO workflow with new server configuration CLI parameters and multi-server scheduling logic. Adds server setup capabilities for judge services with GPU-aware task ordering and server-first execution when appropriate.
**Sequence Diagram**

```mermaid
sequenceDiagram
    participant GRPO as GRPO CLI
    participant Scheduler as Task Scheduler
    participant GPU_Mgr as GPU Allocator
    participant Server as Judge Server
    participant Trainer as Training Task
    GRPO->>GRPO: Parse server config params
    GRPO->>Scheduler: add_task(server_config, n_servers=N)
    Scheduler->>GPU_Mgr: Determine client_num_gpus
    GPU_Mgr-->>Scheduler: GPU availability
    alt Server needs GPUs & Client has none
        Scheduler->>Scheduler: Schedule server tasks FIRST
        Scheduler->>Server: Create N server executors
        Server->>Server: Allocate GPUs
    else Client has GPUs
        Scheduler->>Trainer: Schedule training task
        Trainer->>Trainer: Allocate GPUs
        Scheduler->>Server: Add server tasks after main
    end
    Scheduler->>Scheduler: Set log_prefix & het_groups
    Scheduler-->>GRPO: Task configuration ready
    GRPO->>Server: Launch judge server instance(s)
    GRPO->>Trainer: Launch training with judge reference
```
**Estimated code review effort:** 🎯 4 (Complex) | ⏱️ ~60 minutes

**🚥 Pre-merge checks:** ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
**🧹 Nitpick comments (3)**

`nemo_skills/pipeline/utils/exp.py` (2)

- **594-612: `client_num_gpus` is reassigned with different semantics — consider a distinct variable name.** Line 531 sets `client_num_gpus = num_gpus or 0` to decide server ordering/overlap. Line 594 reassigns it to `0` when `server_config is not None and num_nodes == 1`, changing the meaning from "does the client need GPUs at all" to "how many GPUs to allocate for the main SLURM srun." This shadowing makes the control flow hard to follow — a reader (or future editor) must track which assignment is live at each usage site. Consider using a distinct name like `main_task_gpus` for the line 594 assignment.
- **436-436: Consider validating `n_servers >= 1` when `server_config` is provided.** If `n_servers=0` is passed with a non-None `server_config`, no server tasks are created, yet `server_config` is still consumed (e.g., popping `"container"`, computing `server_goes_first`). This would silently produce a misconfigured job. A simple guard early in the function would prevent this.

`nemo_skills/pipeline/nemo_rl/grpo.py` (1)

- **409-411: Consider using `raise ValueError` instead of `assert` for argument validation.** `assert` statements can be disabled with `python -O`. While unlikely in a CLI context, using explicit `raise ValueError(...)` or `typer.BadParameter(...)` is more robust and consistent with the existing validation patterns in this file (e.g., line 399). Also applies to: 428-429. A sketch of this pattern is shown after this list.
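A minimal sketch of the suggested validation style, under the assumption of a Typer CLI like the one in `grpo.py` (the option names here are stand-ins, not the real signature):

```python
from typing import Optional

import typer

app = typer.Typer()

@app.command()
def launch(server_type: Optional[str] = None, server_gpus: Optional[int] = None):
    # Explicit exception instead of `assert`, so the check survives `python -O`.
    if server_type is not None and server_gpus is None:
        raise typer.BadParameter("server_gpus is required when server_type is set")
    typer.echo(f"server_type={server_type}, server_gpus={server_gpus}")

if __name__ == "__main__":
    app()
```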
```python
# Server configuration for LLM-as-a-judge
server_config = None
if server_type is not None:
    get_random_port = should_get_random_port(server_gpus, exclusive)
    if server_address is None:  # we need to host the model
        assert server_gpus is not None, "Need to specify server_gpus if hosting the model"
        server_port = get_free_port(strategy="random") if get_random_port else 5000

        server_config = {
            "model_path": server_model,
            "server_type": server_type,
            "num_gpus": server_gpus,
            "num_nodes": server_nodes,
            "server_args": server_args,
            "server_port": server_port,
        }
```
**Missing validation: `server_model` should be required when `server_type` is specified.**

When `server_type` is provided but `server_model` is omitted (defaults to `None`), `model_path=None` is passed into `get_server_command`, producing a command string containing the literal string `"None"`. This would fail at runtime with a confusing error.

**Proposed fix**
```diff
 server_config = None
 if server_type is not None:
+    if server_model is None:
+        raise ValueError("server_model is required when server_type is specified")
     get_random_port = should_get_random_port(server_gpus, exclusive)
```

**🤖 Prompt for AI Agents**
```
In @nemo_skills/pipeline/nemo_rl/grpo.py around lines 405 - 420, When
server_type is provided but you intend to host the model (server_address is
None), ensure server_model is required and non-empty to avoid passing
model_path=None into get_server_command; add a validation check (e.g., assert or
raise ValueError) before building server_config that server_model is not
None/empty, referencing the variables server_type, server_address, and
server_model and the block that constructs server_config so the code fails fast
with a clear message instead of producing "None" in the command.
```
```python
server_needs_gpus = server_config is not None and int(server_config.get("num_gpus", 0)) > 0
client_num_gpus = num_gpus or 0
# For ray heterogenous jobs, nemo-run assumes the first het group is the main task
# So we send the server last if the job needs gpus
server_goes_first = server_needs_gpus and not client_num_gpus
```
**🛠️ Refactor suggestion | 🟠 Major**

**Use direct dict access for `num_gpus` per coding guidelines.**

`server_config["num_gpus"]` is required by `get_server_command` downstream. Using `.get()` with a default of `0` masks a missing key that would fail later with a less clear error.
**Proposed fix**

```diff
-server_needs_gpus = server_config is not None and int(server_config.get("num_gpus", 0)) > 0
+server_needs_gpus = server_config is not None and int(server_config["num_gpus"]) > 0
```

As per coding guidelines, "Do not use `.get()` for accessing dictionary keys if the code expects them to be present; use direct dictionary access `dict[key]` instead to allow proper error handling and fail fast with clear errors".
🤖 Prompt for AI Agents
In `@nemo_skills/pipeline/utils/exp.py` around lines 530 - 534, The code should
fail-fast on a missing num_gpus key: replace int(server_config.get("num_gpus",
0)) with int(server_config["num_gpus"]) in the server_needs_gpus calculation
(keep the existing server_config is not None check), so server_needs_gpus and
server_goes_first reflect the required presence of server_config["num_gpus"]
used later by get_server_command; this ensures a clear KeyError rather than
silently using 0.
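A tiny, self-contained illustration of the fail-fast difference (a toy dict, not the real pipeline config):

```python
# Toy config where "num_gpus" is accidentally missing.
server_config = {"num_nodes": 1}

# .get() with a default silently masks the missing key:
print(int(server_config.get("num_gpus", 0)) > 0)  # False -- looks like a valid answer

# Direct access surfaces the bug immediately, at the point of use:
try:
    int(server_config["num_gpus"])
except KeyError as err:
    print(f"KeyError: {err}")  # KeyError: 'num_gpus'
```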
```diff
 def add_server_tasks():
     nonlocal het_group
     # Get container once (same for all servers)
     server_container = server_config.pop("container", None)
     if server_container is None:
         server_container = cluster_config["containers"][server_config["server_type"]]
-    server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
-    server_executor = get_executor(
-        cluster_config=cluster_config,
-        container=server_container,
-        num_nodes=server_config["num_nodes"],
-        tasks_per_node=num_server_tasks,
-        gpus_per_node=server_config["num_gpus"],
-        partition=partition,
-        dependencies=dependencies,
-        job_name=task_name,
-        log_dir=log_dir,
-        log_prefix="server",
-        extra_package_dirs=extra_package_dirs,
-        sbatch_kwargs=sbatch_kwargs,
-        heterogeneous=heterogeneous,
-        het_group=het_group,
-        total_het_groups=total_het_groups,
-        with_ray=with_ray,
-        ray_template=ray_template,
-    )
-    if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
-        server_cmd = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
-    commands.append(server_cmd)
-    executors.append(server_executor)
-    het_group_indices.append(het_group)
-    het_group += 1
-    LOG.info("Server command: %s", server_cmd)
+    for server_idx in range(n_servers):
+        server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
+        server_executor = get_executor(
+            cluster_config=cluster_config,
+            container=server_container,
+            num_nodes=server_config["num_nodes"],
+            tasks_per_node=num_server_tasks,
+            gpus_per_node=server_config["num_gpus"],
+            partition=partition,
+            dependencies=dependencies,
+            job_name=task_name,
+            log_dir=log_dir,
+            log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
+            extra_package_dirs=extra_package_dirs,
+            sbatch_kwargs=sbatch_kwargs,
+            heterogeneous=heterogeneous,
+            het_group=het_group,
+            total_het_groups=total_het_groups,
+            overlap=(not client_num_gpus),  # Only overlap when the main task does not have gpus
+            with_ray=False,
+            ray_template=ray_template,
+        )
+        cmd_to_add = server_cmd
+        if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
+            cmd_to_add = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
+        commands.append(cmd_to_add)
+        executors.append(server_executor)
+        het_group_indices.append(het_group)
+        het_group += 1
+        LOG.info("Server %d command: %s", server_idx, server_cmd)

 # then goes the main task(s) unless it's empty
```
**`server_config.pop("container")` mutates the caller's dictionary — use `.get()` plus a conditional `del` on a copy, or pass `container` separately.**

`add_server_tasks` pops `"container"` from the shared `server_config` dict. Since `add_task` can be called in a loop (e.g., the `dependent_jobs` loop in `grpo.py`), a caller that includes `"container"` in `server_config` would silently lose it on the second iteration. Additionally, `get_server_command(**server_config, ...)` on line 544 would fail if `server_config` ever gains an unexpected key, since `pop` is the only thing currently preventing `"container"` from being forwarded.

Consider working on a shallow copy:

**Proposed fix**
```diff
 def add_server_tasks():
     nonlocal het_group
-    # Get container once (same for all servers)
-    server_container = server_config.pop("container", None)
+    # Get container without mutating the original config
+    server_container = server_config.get("container", None)
     if server_container is None:
         server_container = cluster_config["containers"][server_config["server_type"]]
+    # Build a config without 'container' to pass to get_server_command
+    server_cmd_config = {k: v for k, v in server_config.items() if k != "container"}
     for server_idx in range(n_servers):
-        server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
+        server_cmd, num_server_tasks = get_server_command(**server_cmd_config, cluster_config=cluster_config)
```

**🤖 Prompt for AI Agents**
+ server_cmd, num_server_tasks = get_server_command(**server_cmd_config, cluster_config=cluster_config)🤖 Prompt for AI Agents
In `@nemo_skills/pipeline/utils/exp.py` around lines 536 - 572, The
add_server_tasks function currently mutates the shared server_config by calling
server_config.pop("container", ...); instead, make a shallow copy of
server_config at the start of add_server_tasks (e.g., local_server_cfg =
server_config.copy()), read the container with local_server_cfg.get("container")
and if you need to remove it from what you pass onward, delete it from the copy
(del local_server_cfg["container"]) before calling
get_server_command(**local_server_cfg, cluster_config=cluster_config) and
passing local_server_cfg to other helpers (or pass container separately) so the
original server_config remains unchanged and unexpected keys are not forwarded
to get_server_command.
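A small standalone demonstration of the pitfall this comment describes, using toy dicts and made-up names rather than the real pipeline API:

```python
def add_server_tasks(server_config, containers):
    # BUG: pop() mutates the caller's dict, so "container" is gone on the next call.
    container = server_config.pop("container", None)
    if container is None:
        container = containers[server_config["server_type"]]
    return container

containers = {"vllm": "default-vllm-image"}
shared_config = {"server_type": "vllm", "container": "custom-image"}

print(add_server_tasks(shared_config, containers))  # custom-image
print(add_server_tasks(shared_config, containers))  # default-vllm-image -- silently wrong

def add_server_tasks_fixed(server_config, containers):
    # Read without mutating, and strip the key from a copy before forwarding.
    container = server_config.get("container") or containers[server_config["server_type"]]
    cmd_config = {k: v for k, v in server_config.items() if k != "container"}
    return container, cmd_config
```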
```diff
@@ -433,6 +433,7 @@ def add_task(
     keep_mounts_for_sandbox=False,
```
@gwarmstrong do we need to update the declarative code path to reflect these changes?
It will need to be updated when we want to use the feature on the declarative path, but at the moment I'm not sure there's value in adding it to the declarative path purely for parity's sake.

Is there any way to ensure it's covered by some test case (GPU or Slurm, probably)? That way, when we convert to declarative, we can make sure the functionality isn't dropped. A rough sketch is below.
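A rough sketch of what such a test might look like; `build_server_configs` is a made-up helper that mirrors the per-server copy/port logic from the suggestion above (the real entry point would be `add_task` with a mocked cluster config):

```python
def build_server_configs(server_config, n_servers, port_fn):
    configs = []
    for _ in range(n_servers):
        cfg = server_config.copy()  # never mutate the caller's dict
        if n_servers > 1:
            cfg["server_port"] = port_fn()
        configs.append(cfg)
    return configs

def test_multi_server_ports_are_unique_and_config_is_untouched():
    ports = iter([5001, 5002, 5003])
    base = {"server_port": 5000, "num_gpus": 8, "num_nodes": 1}
    configs = build_server_configs(base, 3, lambda: next(ports))
    assert len({c["server_port"] for c in configs}) == 3  # no port conflicts
    assert base["server_port"] == 5000  # shared config not mutated

test_multi_server_ports_are_unique_and_config_is_untouched()
```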
```python
with temporary_env_update(cluster_config, {"NEMO_SKILLS_SANDBOX_PORT": sandbox_port}):
    cur_cmd = install_packages_wrap(cur_cmd, installation_command)
commands.append(cur_cmd)
client_num_gpus = num_gpus if (server_config is None or num_nodes > 1) else 0
```
`client_num_gpus` is calculated here inside the loop, but it was already defined at line 531. This shadows the outer variable and is calculated in the wrong scope (it should sit outside the `for` loop, at lines 588-617).
Move this line before line 588 (before the `for cur_idx, (cur_cmd...` loop starts); a toy illustration of the shadowing hazard follows.
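A self-contained toy showing the shadowing hazard (made-up values, not the real pipeline code):

```python
num_gpus, num_nodes = 8, 1
server_config = {"num_gpus": 2}

client_num_gpus = num_gpus or 0  # outer meaning: "does the client need GPUs at all?"

for _ in range(2):  # stand-in for the per-command loop
    # Reassignment inside the loop silently changes the variable's meaning.
    client_num_gpus = num_gpus if (server_config is None or num_nodes > 1) else 0

print(client_num_gpus)  # 0 -- any later truthiness check now sees the wrong value
```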
Allow heterogeneous servers for nemo-rl jobs
Summary by CodeRabbit