Skip to content

Commit 6693470

Browse files
committed
Honor executor srun_args for Ray command srun
1 parent c41c729 commit 6693470

3 files changed

Lines changed: 39 additions & 7 deletions

File tree

nemo_run/run/ray/slurm.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,19 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
278278

279279
return " ".join(_srun_flags)
280280

def get_command_srun_args() -> str:
    """Build the shell-quoted srun argument string for the Ray COMMAND launch.

    For heterogeneous group jobs, the first resource group's ``srun_args``
    take precedence when explicitly set; otherwise fall back to the
    executor-level ``srun_args`` (treating ``None`` as no arguments).
    """
    # NOTE(review): relies on `self` from the enclosing method's scope.
    heterogeneous_group = (
        self.executor.run_as_group
        and self.executor.heterogeneous
        and self.executor.resource_group
    )
    if heterogeneous_group and self.executor.resource_group[0].srun_args is not None:
        selected_args = self.executor.resource_group[0].srun_args
    else:
        selected_args = self.executor.srun_args or []

    return " ".join(map(shlex.quote, selected_args))
293+
281294
ray_log_prefix = job_details.ray_log_prefix
282295
vars_to_fill = {
283296
"sbatch_flags": sbatch_flags,
@@ -296,6 +309,7 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str:
296309
"ray_log_prefix": ray_log_prefix,
297310
"heterogeneous": self.executor.heterogeneous,
298311
"resource_group": self.executor.resource_group if self.executor.heterogeneous else [],
312+
"command_srun_args": get_command_srun_args(),
299313
}
300314

301315
if self.command_groups:
@@ -1257,9 +1271,9 @@ def start(
12571271
if isinstance(self.executor.tunnel, SSHTunnel):
12581272
# Rsync workdir honouring .gitignore
12591273
self.executor.tunnel.connect()
1260-
assert self.executor.tunnel.session is not None, (
1261-
"Tunnel session is not connected"
1262-
)
1274+
assert (
1275+
self.executor.tunnel.session is not None
1276+
), "Tunnel session is not connected"
12631277
rsync(
12641278
self.executor.tunnel.session,
12651279
workdir,
@@ -1314,9 +1328,9 @@ def start(
13141328

13151329
if isinstance(self.executor.tunnel, SSHTunnel):
13161330
self.executor.tunnel.connect()
1317-
assert self.executor.tunnel.session is not None, (
1318-
"Tunnel session is not connected"
1319-
)
1331+
assert (
1332+
self.executor.tunnel.session is not None
1333+
), "Tunnel session is not connected"
13201334
rsync(
13211335
self.executor.tunnel.session,
13221336
os.path.join(local_code_extraction_path, ""),

nemo_run/run/ray/templates/ray.sub.j2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ COMMAND="${COMMAND:-{{ command | default('', true) }}}"
454454
COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }}
455455
456456
if [[ -n "$COMMAND" ]]; then
457-
srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
457+
srun {% if heterogeneous %}--het-group=0 {% endif %}--no-container-mount-home --gpus=0 --overlap {% if command_srun_args %}{{ command_srun_args }} {% endif %}--container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND"
458458
else
459459
echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:"
460460
cat <<EOF >$CLUSTER_DIR/scripts/${SLURM_JOB_ID}-attach.sh

test/run/ray/test_slurm_ray_request.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,24 @@ def test_command_groups_without_resource_group(self):
627627
assert "--overlap" in script
628628
assert "cmd1" in script # Second command in the list (index 1)
629629

def test_command_srun_honors_executor_srun_args(self):
    """Test that the COMMAND launch srun includes executor srun_args."""
    srun_executor = SlurmExecutor(account="test_account", srun_args=["--mpi=pmix"])
    srun_executor.tunnel = Mock(spec=SSHTunnel)
    srun_executor.tunnel.job_dir = "/tmp/test_jobs"

    ray_request = SlurmRayRequest(
        name="test-ray-cluster",
        cluster_dir="/tmp/test_jobs/test-ray-cluster",
        template_name="ray.sub.j2",
        executor=srun_executor,
        command="echo hello",
        launch_cmd=["sbatch", "--parsable"],
    )

    rendered_script = ray_request.materialize()
    # The executor srun_args must appear on the COMMAND srun line, between
    # the fixed --overlap flag and the container flags.
    expected_fragment = "--gpus=0 --overlap --mpi=pmix --container-name=ray-head"
    assert expected_fragment in rendered_script
647+
630648
def test_env_vars_formatting(self):
631649
"""Test that environment variables are properly formatted as export statements."""
632650
executor = SlurmExecutor(

0 commit comments

Comments (0)