Skip to content

Comments

Allow het servers for nemo-rl jobs#1223

Open
smahdavi4 wants to merge 5 commits intomainfrom
smahdavi/het-job-judg
Open

Allow het servers for nemo-rl jobs#1223
smahdavi4 wants to merge 5 commits intomainfrom
smahdavi/het-job-judg

Conversation

@smahdavi4
Copy link
Collaborator

@smahdavi4 smahdavi4 commented Feb 9, 2026

Allow heterogenous servers for nemo-rl jobs

Summary by CodeRabbit

  • New Features
    • Added LLM-as-a-judge support to GRPO workflow with flexible server configuration options.
    • New CLI parameters for judge server setup, including model selection, address, deployment type, GPU allocation, and multi-server management.
    • Support for both local and remote judge server deployment with automatic port allocation.

Signed-off-by: Sadegh Mahdavi <smahdavi@nvidia.com>
Signed-off-by: Sadegh Mahdavi <smahdavi@nvidia.com>
Signed-off-by: Sadegh Mahdavi <smahdavi@nvidia.com>
Signed-off-by: Sadegh Mahdavi <smahdavi@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +543 to +572
for server_idx in range(n_servers):
server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
server_executor = get_executor(
cluster_config=cluster_config,
container=server_container,
num_nodes=server_config["num_nodes"],
tasks_per_node=num_server_tasks,
gpus_per_node=server_config["num_gpus"],
partition=partition,
dependencies=dependencies,
job_name=task_name,
log_dir=log_dir,
log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
extra_package_dirs=extra_package_dirs,
sbatch_kwargs=sbatch_kwargs,
heterogeneous=heterogeneous,
het_group=het_group,
total_het_groups=total_het_groups,
overlap=(not client_num_gpus), # Only overlap when the main task does not have gpus
with_ray=False,
ray_template=ray_template,
)
cmd_to_add = server_cmd
if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
cmd_to_add = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
commands.append(cmd_to_add)
executors.append(server_executor)
het_group_indices.append(het_group)
het_group += 1
LOG.info("Server %d command: %s", server_idx, server_cmd)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all servers launched in the loop use the same server_port from server_config, causing port conflicts when n_servers > 1

each server instance needs a unique port

Suggested change
for server_idx in range(n_servers):
server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
server_executor = get_executor(
cluster_config=cluster_config,
container=server_container,
num_nodes=server_config["num_nodes"],
tasks_per_node=num_server_tasks,
gpus_per_node=server_config["num_gpus"],
partition=partition,
dependencies=dependencies,
job_name=task_name,
log_dir=log_dir,
log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
extra_package_dirs=extra_package_dirs,
sbatch_kwargs=sbatch_kwargs,
heterogeneous=heterogeneous,
het_group=het_group,
total_het_groups=total_het_groups,
overlap=(not client_num_gpus), # Only overlap when the main task does not have gpus
with_ray=False,
ray_template=ray_template,
)
cmd_to_add = server_cmd
if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
cmd_to_add = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
commands.append(cmd_to_add)
executors.append(server_executor)
het_group_indices.append(het_group)
het_group += 1
LOG.info("Server %d command: %s", server_idx, server_cmd)
for server_idx in range(n_servers):
# Get a unique port for each server if launching multiple
current_server_config = server_config.copy()
if n_servers > 1:
current_server_config["server_port"] = get_free_port(strategy="random")
server_cmd, num_server_tasks = get_server_command(**current_server_config, cluster_config=cluster_config)
server_executor = get_executor(
cluster_config=cluster_config,
container=server_container,
num_nodes=current_server_config["num_nodes"],
tasks_per_node=num_server_tasks,
gpus_per_node=current_server_config["num_gpus"],
partition=partition,
dependencies=dependencies,
job_name=task_name,
log_dir=log_dir,
log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
extra_package_dirs=extra_package_dirs,
sbatch_kwargs=sbatch_kwargs,
heterogeneous=heterogeneous,
het_group=het_group,
total_het_groups=total_het_groups,
overlap=(not client_num_gpus), # Only overlap when the main task does not have gpus
with_ray=False,
ray_template=ray_template,
)

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 9, 2026

📝 Walkthrough

Walkthrough

Introduces LLM-as-a-judge server support to GRPO workflow with new server configuration CLI parameters and multi-server scheduling logic. Adds server setup capabilities for judge services with GPU-aware task ordering and server-first execution when appropriate.

Changes

Cohort / File(s) Summary
GRPO Server Configuration
nemo_skills/pipeline/nemo_rl/grpo.py
Adds seven new server-related CLI parameters (server_model, server_address, server_type, server_gpus, server_nodes, n_servers, server_args) to support judge service orchestration. Implements server configuration logic with local or remote hosting support, port allocation, and JUDGE_SERVER_ARGS environment injection. Threads server config and n_servers into task creation.
Task Scheduling Multi-Server Support
nemo_skills/pipeline/utils/exp.py
Introduces multi-server scheduling with new n_servers parameter. Adds GPU-aware server-first scheduling logic, adjusts het_groups calculation for multiple servers, and reworks server/main task ordering. Implements per-server command construction with server-specific log prefixes and GPU overlap handling based on client GPU usage. Creates multiple server executors when n_servers > 1.

Sequence Diagram

sequenceDiagram
    participant GRPO as GRPO CLI
    participant Scheduler as Task Scheduler
    participant GPU_Mgr as GPU Allocator
    participant Server as Judge Server
    participant Trainer as Training Task

    GRPO->>GRPO: Parse server config params
    GRPO->>Scheduler: add_task(server_config, n_servers=N)
    
    Scheduler->>GPU_Mgr: Determine client_num_gpus
    GPU_Mgr-->>Scheduler: GPU availability
    
    alt Server needs GPUs & Client has none
        Scheduler->>Scheduler: Schedule server tasks FIRST
        Scheduler->>Server: Create N server executors
        Server->>Server: Allocate GPUs
    else Client has GPUs
        Scheduler->>Trainer: Schedule training task
        Trainer->>Trainer: Allocate GPUs
        Scheduler->>Server: Add server tasks after main
    end
    
    Scheduler->>Scheduler: Set log_prefix & het_groups
    Scheduler-->>GRPO: Task configuration ready
    
    GRPO->>Server: Launch judge server instance(s)
    GRPO->>Trainer: Launch training with judge reference
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately reflects the main changes: adding support for heterogeneous servers to nemo-rl jobs via server configuration parameters and multi-server scheduling logic in GRPO workflow.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch smahdavi/het-job-judg

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@nemo_skills/pipeline/nemo_rl/grpo.py`:
- Around line 405-420: When server_type is provided but you intend to host the
model (server_address is None), ensure server_model is required and non-empty to
avoid passing model_path=None into get_server_command; add a validation check
(e.g., assert or raise ValueError) before building server_config that
server_model is not None/empty, referencing the variables server_type,
server_address, and server_model and the block that constructs server_config so
the code fails fast with a clear message instead of producing "None" in the
command.

In `@nemo_skills/pipeline/utils/exp.py`:
- Around line 530-534: The code should fail-fast on a missing num_gpus key:
replace int(server_config.get("num_gpus", 0)) with
int(server_config["num_gpus"]) in the server_needs_gpus calculation (keep the
existing server_config is not None check), so server_needs_gpus and
server_goes_first reflect the required presence of server_config["num_gpus"]
used later by get_server_command; this ensures a clear KeyError rather than
silently using 0.
- Around line 536-572: The add_server_tasks function currently mutates the
shared server_config by calling server_config.pop("container", ...); instead,
make a shallow copy of server_config at the start of add_server_tasks (e.g.,
local_server_cfg = server_config.copy()), read the container with
local_server_cfg.get("container") and if you need to remove it from what you
pass onward, delete it from the copy (del local_server_cfg["container"]) before
calling get_server_command(**local_server_cfg, cluster_config=cluster_config)
and passing local_server_cfg to other helpers (or pass container separately) so
the original server_config remains unchanged and unexpected keys are not
forwarded to get_server_command.
🧹 Nitpick comments (3)
nemo_skills/pipeline/utils/exp.py (2)

594-612: client_num_gpus is reassigned with different semantics — consider a distinct variable name.

Line 531 sets client_num_gpus = num_gpus or 0 to decide server ordering/overlap. Line 594 reassigns it to 0 when server_config is not None and num_nodes == 1, changing the meaning from "does the client need GPUs at all" to "how many GPUs to allocate for the main SLURM srun." This shadowing makes the control flow hard to follow — a reader (or future editor) must track which assignment is live at each usage site.

Consider using a distinct name like main_task_gpus for the line 594 assignment.


436-436: Consider validating n_servers >= 1 when server_config is provided.

If n_servers=0 is passed with a non-None server_config, no server tasks are created, yet server_config is still consumed (e.g., popping "container", computing server_goes_first). This would silently produce a misconfigured job. A simple guard early in the function would prevent this.

nemo_skills/pipeline/nemo_rl/grpo.py (1)

409-411: Consider using raise ValueError instead of assert for argument validation.

assert statements can be disabled with python -O. While unlikely in a CLI context, using explicit raise ValueError(...) or typer.BadParameter(...) is more robust and consistent with the existing validation patterns in this file (e.g., line 399).

Also applies to: 428-429

Comment on lines +405 to +420
# Server configuration for LLM-as-a-judge
server_config = None
if server_type is not None:
get_random_port = should_get_random_port(server_gpus, exclusive)
if server_address is None: # we need to host the model
assert server_gpus is not None, "Need to specify server_gpus if hosting the model"
server_port = get_free_port(strategy="random") if get_random_port else 5000

server_config = {
"model_path": server_model,
"server_type": server_type,
"num_gpus": server_gpus,
"num_nodes": server_nodes,
"server_args": server_args,
"server_port": server_port,
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Missing validation: server_model should be required when server_type is specified.

When server_type is provided but server_model is omitted (defaults to None), model_path=None is passed into get_server_command, producing a command string containing the literal string "None". This would fail at runtime with a confusing error.

Proposed fix
     server_config = None
     if server_type is not None:
+        if server_model is None:
+            raise ValueError("server_model is required when server_type is specified")
         get_random_port = should_get_random_port(server_gpus, exclusive)
🤖 Prompt for AI Agents
In `@nemo_skills/pipeline/nemo_rl/grpo.py` around lines 405 - 420, When
server_type is provided but you intend to host the model (server_address is
None), ensure server_model is required and non-empty to avoid passing
model_path=None into get_server_command; add a validation check (e.g., assert or
raise ValueError) before building server_config that server_model is not
None/empty, referencing the variables server_type, server_address, and
server_model and the block that constructs server_config so the code fails fast
with a clear message instead of producing "None" in the command.

Comment on lines +530 to +534
server_needs_gpus = server_config is not None and int(server_config.get("num_gpus", 0)) > 0
client_num_gpus = num_gpus or 0
# For ray heterogenous jobs, nemo-run assumes the first het group is the main task
# So we send the server last if the job needs gpus
server_goes_first = server_needs_gpus and not client_num_gpus
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Use direct dict access for num_gpus per coding guidelines.

server_config["num_gpus"] is required by get_server_command downstream. Using .get() with a default of 0 masks a missing key that would fail later with a less clear error.

Proposed fix
-    server_needs_gpus = server_config is not None and int(server_config.get("num_gpus", 0)) > 0
+    server_needs_gpus = server_config is not None and int(server_config["num_gpus"]) > 0

As per coding guidelines, "Do not use .get() for accessing dictionary keys if the code expects them to be present; use direct dictionary access dict[key] instead to allow proper error handling and fail fast with clear errors".

🤖 Prompt for AI Agents
In `@nemo_skills/pipeline/utils/exp.py` around lines 530 - 534, The code should
fail-fast on a missing num_gpus key: replace int(server_config.get("num_gpus",
0)) with int(server_config["num_gpus"]) in the server_needs_gpus calculation
(keep the existing server_config is not None check), so server_needs_gpus and
server_goes_first reflect the required presence of server_config["num_gpus"]
used later by get_server_command; this ensures a clear KeyError rather than
silently using 0.

Comment on lines +536 to +572
def add_server_tasks():
nonlocal het_group
# Get container once (same for all servers)
server_container = server_config.pop("container", None)
if server_container is None:
server_container = cluster_config["containers"][server_config["server_type"]]
server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
server_executor = get_executor(
cluster_config=cluster_config,
container=server_container,
num_nodes=server_config["num_nodes"],
tasks_per_node=num_server_tasks,
gpus_per_node=server_config["num_gpus"],
partition=partition,
dependencies=dependencies,
job_name=task_name,
log_dir=log_dir,
log_prefix="server",
extra_package_dirs=extra_package_dirs,
sbatch_kwargs=sbatch_kwargs,
heterogeneous=heterogeneous,
het_group=het_group,
total_het_groups=total_het_groups,
with_ray=with_ray,
ray_template=ray_template,
)
if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
server_cmd = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
commands.append(server_cmd)
executors.append(server_executor)
het_group_indices.append(het_group)
het_group += 1
LOG.info("Server command: %s", server_cmd)

# then goes the main task(s) unless it's empty
for server_idx in range(n_servers):
server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
server_executor = get_executor(
cluster_config=cluster_config,
container=server_container,
num_nodes=server_config["num_nodes"],
tasks_per_node=num_server_tasks,
gpus_per_node=server_config["num_gpus"],
partition=partition,
dependencies=dependencies,
job_name=task_name,
log_dir=log_dir,
log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
extra_package_dirs=extra_package_dirs,
sbatch_kwargs=sbatch_kwargs,
heterogeneous=heterogeneous,
het_group=het_group,
total_het_groups=total_het_groups,
overlap=(not client_num_gpus), # Only overlap when the main task does not have gpus
with_ray=False,
ray_template=ray_template,
)
cmd_to_add = server_cmd
if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
cmd_to_add = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
commands.append(cmd_to_add)
executors.append(server_executor)
het_group_indices.append(het_group)
het_group += 1
LOG.info("Server %d command: %s", server_idx, server_cmd)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

server_config.pop("container") mutates the caller's dictionary — use .get() + conditional del on a copy, or pass container separately.

add_server_tasks pops "container" from the shared server_config dict. Since add_task can be called in a loop (e.g., the dependent_jobs loop in grpo.py), a caller that includes "container" in server_config would silently lose it on the second iteration. Additionally, get_server_command(**server_config, ...) on line 544 would fail if server_config ever gains an unexpected key, since pop is the only thing currently preventing "container" from being forwarded.

Consider working on a shallow copy:

Proposed fix
     def add_server_tasks():
         nonlocal het_group
-        # Get container once (same for all servers)
-        server_container = server_config.pop("container", None)
+        # Get container without mutating the original config
+        server_container = server_config.get("container", None)
         if server_container is None:
             server_container = cluster_config["containers"][server_config["server_type"]]
+        # Build a config without 'container' to pass to get_server_command
+        server_cmd_config = {k: v for k, v in server_config.items() if k != "container"}
 
         for server_idx in range(n_servers):
-            server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
+            server_cmd, num_server_tasks = get_server_command(**server_cmd_config, cluster_config=cluster_config)
🤖 Prompt for AI Agents
In `@nemo_skills/pipeline/utils/exp.py` around lines 536 - 572, The
add_server_tasks function currently mutates the shared server_config by calling
server_config.pop("container", ...); instead, make a shallow copy of
server_config at the start of add_server_tasks (e.g., local_server_cfg =
server_config.copy()), read the container with local_server_cfg.get("container")
and if you need to remove it from what you pass onward, delete it from the copy
(del local_server_cfg["container"]) before calling
get_server_command(**local_server_cfg, cluster_config=cluster_config) and
passing local_server_cfg to other helpers (or pass container separately) so the
original server_config remains unchanged and unexpected keys are not forwarded
to get_server_command.

@@ -433,6 +433,7 @@ def add_task(
keep_mounts_for_sandbox=False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gwarmstrong do we need to update the declarative code path to reflect these changes?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will need to be updated when we want to use the feature on the declarative path, but at the moment I'm not sure there is value to adding it to the declarative path purely for parity sake.

Is there any way to ensure it is covered by some test case (gpu or slurm probably?) that way when we convert to declarative, we can make sure the functionality isn't dropped?

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

with temporary_env_update(cluster_config, {"NEMO_SKILLS_SANDBOX_PORT": sandbox_port}):
cur_cmd = install_packages_wrap(cur_cmd, installation_command)
commands.append(cur_cmd)
client_num_gpus = num_gpus if (server_config is None or num_nodes > 1) else 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

client_num_gpus is calculated here inside the loop, but was already defined at line 531. This shadows the outer variable and is calculated inside the wrong scope (should be outside the for loop at lines 588-617).

Suggested change
client_num_gpus = num_gpus if (server_config is None or num_nodes > 1) else 0
client_num_gpus = num_gpus if (server_config is None or num_nodes > 1) else 0

Move this line before line 588 (before the for cur_idx, (cur_cmd... loop starts).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants