Skip to content

Commit 4d4f041

Browse files
authored
feat: use slurm executor to get ray template name (#410)
1 parent 2ccf1c9 commit 4d4f041

3 files changed

Lines changed: 25 additions & 12 deletions

File tree

nemo_run/core/execution/slurm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,8 @@ class ResourceRequest:
344344
het_group_indices: Optional[list[int]] = None
345345
segment: Optional[int] = None
346346
network: Optional[str] = None
347+
#: Template name to use for Ray jobs (e.g., "ray.sub.j2" or "ray_enroot.sub.j2")
348+
ray_template: str = "ray.sub.j2"
347349

348350
#: Set by the executor; cannot be initialized
349351
job_name: str = field(init=False, default="nemo-job")

nemo_run/run/torchx_backend/schedulers/slurm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@
5050
)
5151
from torchx.specs.api import is_terminal
5252

53-
from nemo_run.config import RUNDIR_NAME, USE_WITH_RAY_CLUSTER_KEY, from_dict, get_nemorun_home
53+
from nemo_run.config import (
54+
RUNDIR_NAME,
55+
USE_WITH_RAY_CLUSTER_KEY,
56+
from_dict,
57+
get_nemorun_home,
58+
)
5459
from nemo_run.core.execution.base import Executor
5560
from nemo_run.core.execution.slurm import SlurmBatchRequest, SlurmExecutor, SlurmJobDetails
5661
from nemo_run.core.tunnel.client import LocalTunnel, PackagingJob, SSHTunnel, Tunnel
@@ -125,8 +130,8 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
125130
)
126131

127132
command = [app.roles[0].entrypoint] + app.roles[0].args
128-
# Allow selecting Ray template via environment variable
129-
ray_template_name = os.environ.get("NEMO_RUN_SLURM_RAY_TEMPLATE", "ray.sub.j2")
133+
# Use Ray template from executor configuration
134+
ray_template_name = executor.ray_template
130135
req = SlurmRayRequest(
131136
name=app.roles[0].name,
132137
launch_cmd=["sbatch", "--requeue", "--parsable"],

test/run/torchx_backend/schedulers/test_slurm.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,8 @@ def test_schedule_with_dependencies(slurm_scheduler, slurm_executor):
368368
mock_tunnel.run.assert_called_once()
369369

370370

371-
def test_ray_template_env_var(slurm_scheduler, slurm_executor):
372-
"""Test that NEMO_RUN_SLURM_RAY_TEMPLATE environment variable selects the correct template."""
371+
def test_ray_template_executor(slurm_scheduler, slurm_executor, temp_dir):
372+
"""Test that executor.ray_template selects the correct template."""
373373
from nemo_run.config import USE_WITH_RAY_CLUSTER_KEY
374374
from nemo_run.run.ray.slurm import SlurmRayRequest
375375

@@ -387,20 +387,26 @@ def test_ray_template_env_var(slurm_scheduler, slurm_executor):
387387
):
388388
slurm_scheduler.tunnel = mock.MagicMock()
389389

390-
# Test default template name
390+
# Test default template name (ray.sub.j2)
391+
assert slurm_executor.ray_template == "ray.sub.j2"
391392
with mock.patch("nemo_run.core.execution.utils.fill_template") as mock_fill:
392393
mock_fill.return_value = "#!/bin/bash\n# Mock script"
393394
dryrun_info = slurm_scheduler._submit_dryrun(app_def, slurm_executor)
394395
assert isinstance(dryrun_info.request, SlurmRayRequest)
395396
assert dryrun_info.request.template_name == "ray.sub.j2"
396397

397-
# Test custom template name via environment variable
398-
with (
399-
mock.patch.dict(os.environ, {"NEMO_RUN_SLURM_RAY_TEMPLATE": "ray_enroot.sub.j2"}),
400-
mock.patch("nemo_run.core.execution.utils.fill_template") as mock_fill,
401-
):
398+
# Test custom template name via executor
399+
custom_executor = SlurmExecutor(
400+
account="test_account",
401+
job_dir=temp_dir,
402+
nodes=1,
403+
ntasks_per_node=1,
404+
tunnel=LocalTunnel(job_dir=temp_dir),
405+
ray_template="ray_enroot.sub.j2",
406+
)
407+
with mock.patch("nemo_run.core.execution.utils.fill_template") as mock_fill:
402408
mock_fill.return_value = "#!/bin/bash\n# Mock script"
403-
dryrun_info = slurm_scheduler._submit_dryrun(app_def, slurm_executor)
409+
dryrun_info = slurm_scheduler._submit_dryrun(app_def, custom_executor)
404410
assert isinstance(dryrun_info.request, SlurmRayRequest)
405411
assert dryrun_info.request.template_name == "ray_enroot.sub.j2"
406412

0 commit comments

Comments
 (0)