From 6fae01834e896596d6a6e4695846a17282adc994 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Sun, 8 Jun 2025 11:04:32 -0700 Subject: [PATCH 1/4] Support overlapped srun commands in Slurm Ray Signed-off-by: Hemil Desai --- nemo_run/core/execution/slurm.py | 1 - nemo_run/run/ray/slurm.py | 63 +++++++++++++++++++ nemo_run/run/ray/templates/ray.sub.j2 | 14 +++++ nemo_run/run/torchx_backend/packaging.py | 9 ++- .../run/torchx_backend/schedulers/slurm.py | 13 +++- .../artifacts/group_resource_req_slurm.sh | 1 + 6 files changed, 96 insertions(+), 5 deletions(-) diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py index 3f662723..ac2093fd 100644 --- a/nemo_run/core/execution/slurm.py +++ b/nemo_run/core/execution/slurm.py @@ -414,7 +414,6 @@ def merge( ) ) - main_executor.env_vars = {} return main_executor def __post_init__(self): diff --git a/nemo_run/run/ray/slurm.py b/nemo_run/run/ray/slurm.py index abb63fa4..639247a0 100644 --- a/nemo_run/run/ray/slurm.py +++ b/nemo_run/run/ray/slurm.py @@ -117,6 +117,7 @@ class SlurmRayRequest: command: Optional[str] = None workdir: Optional[str] = None nemo_run_dir: Optional[str] = None + command_groups: Optional[list[list[str]]] = None launch_cmd: list[str] @staticmethod @@ -234,6 +235,60 @@ def get_srun_flags(mounts: list[str], container_image: Optional[str]) -> str: "gres_specification": get_gres_specification(), } + if self.command_groups: + srun_commands: list[str] = [] + group_env_vars: list[list[str]] = [] + + for idx, group in enumerate(self.command_groups): + if idx == 0: + continue + + if self.executor.run_as_group and len(self.executor.resource_group) == len( + self.command_groups + ): + req = self.executor.resource_group[idx] + env_list = [f"export {k.upper()}={v}" for k, v in req.env_vars.items()] + group_env_vars.append(env_list) + container_flags = get_srun_flags(req.container_mounts, req.container_image) + srun_args = ["--wait=60", "--kill-on-bad-exit=1", "--overlap"] + srun_args.extend(req.srun_args or []) + else: + container_flags = get_srun_flags( + self.executor.container_mounts, self.executor.container_image + ) + srun_args = ["--wait=60", "--kill-on-bad-exit=1", "--overlap"] + srun_args.extend(self.executor.srun_args or []) + group_env_vars.append([]) + + stdout_path = os.path.join(self.cluster_dir, "logs", f"ray-overlap-{idx}.out") + stderr_flags = [] + if not self.executor.stderr_to_stdout: + stderr_flags = [ + "--error", + os.path.join(self.cluster_dir, "logs", f"ray-overlap-{idx}.err"), + ] + + srun_cmd = " ".join( + list( + map( + lambda arg: arg if isinstance(arg, noquote) else shlex.quote(arg), + [ + "srun", + "--output", + noquote(stdout_path), + *stderr_flags, + container_flags, + *srun_args, + ], + ) + ) + ) + command = " ".join(group) + srun_commands.append(f"{srun_cmd} {command} &") + + vars_to_fill["srun_commands"] = srun_commands + vars_to_fill["group_env_vars"] = group_env_vars + if self.pre_ray_start_commands: vars_to_fill["pre_ray_start_commands"] = "\n".join(self.pre_ray_start_commands) @@ -398,6 +453,7 @@ def create( dryrun: bool = False, command: Optional[str] = None, workdir: Optional[str] = None, + command_groups: Optional[list[list[str]]] = None, ) -> Any: """Create (or reuse) a Slurm-backed Ray cluster and return its job-id. @@ -416,6 +472,9 @@ def create( Optional command executed after the Ray head node is ready (e.g. ``ray job submit``). workdir : str | None Remote working directory that becomes the CWD inside the container. 
+ command_groups : list[list[str]] | None + Additional commands (one per group) executed via ``srun`` with ``--overlap`` + after the cluster is started. Returns ------- @@ -433,6 +492,7 @@ def create( pre_ray_start_commands=pre_ray_start_commands, command=command, workdir=workdir, + command_groups=command_groups, launch_cmd=["sbatch", "--requeue", "--parsable", "--dependency=singleton"], ).materialize() @@ -1094,6 +1154,7 @@ def start( runtime_env_yaml: Optional[str] | None = None, pre_ray_start_commands: Optional[list[str]] = None, dryrun: bool = False, + command_groups: Optional[list[list[str]]] = None, ): """Submit a Ray job via Slurm and return a *live* SlurmRayJob helper. @@ -1106,6 +1167,7 @@ def start( executor=my_slurm_executor, command="python train.py", workdir="./src", + command_groups=[["echo", "hello"]], ) """ # ------------------------------------------------------------------ @@ -1212,6 +1274,7 @@ def start( dryrun=dryrun, command=command, workdir=remote_workdir, + command_groups=command_groups, ) self.job_id = job_id diff --git a/nemo_run/run/ray/templates/ray.sub.j2 b/nemo_run/run/ray/templates/ray.sub.j2 index 4375e15d..6d0eed1c 100644 --- a/nemo_run/run/ray/templates/ray.sub.j2 +++ b/nemo_run/run/ray/templates/ray.sub.j2 @@ -295,6 +295,20 @@ echo "[INFO] Ray cluster information saved to $CLUSTER_DIR/ray_cluster_info.json ######################################################## +{% if srun_commands %} +# Run extra commands +{% for srun_command in srun_commands %} +{%- if loop.index <= group_env_vars|length %} +{%- for env_var in group_env_vars[loop.index - 1] %} +{{env_var}} +{%- endfor %} +{%- endif %} + +{{srun_command}} +{% endfor %} +######################################################## +{% endif -%} + # We can now launch a job on this cluster # We do so by launching a driver process on the physical node that the head node is on # This driver process is responsible for launching a job on the Ray cluster diff --git a/nemo_run/run/torchx_backend/packaging.py b/nemo_run/run/torchx_backend/packaging.py index e915e9b0..e80247de 100644 --- a/nemo_run/run/torchx_backend/packaging.py +++ b/nemo_run/run/torchx_backend/packaging.py @@ -233,7 +233,6 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): assert isinstance(executor, SlurmExecutor), ( f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for SlurmExecutor" ) - assert len(app_def.roles) == 1, "Only one command is supported for Ray jobs." 
app_def.metadata = metadata return app_def @@ -241,7 +240,13 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): def merge_executables(app_defs: Iterator[specs.AppDef], name: str) -> specs.AppDef: result = specs.AppDef(name=name, roles=[]) - for app_def in app_defs: + result.metadata = {} + for idx, app_def in enumerate(app_defs): + metadata = app_def.metadata or {} + if USE_WITH_RAY_CLUSTER_KEY in metadata: + assert idx == 0, f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for the first command" + + result.metadata.update(app_def.metadata) result.roles.extend(app_def.roles) return result diff --git a/nemo_run/run/torchx_backend/schedulers/slurm.py b/nemo_run/run/torchx_backend/schedulers/slurm.py index b65e0504..48c68084 100644 --- a/nemo_run/run/torchx_backend/schedulers/slurm.py +++ b/nemo_run/run/torchx_backend/schedulers/slurm.py @@ -102,8 +102,17 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t executor.package(packager=executor.packager, job_name=Path(job_dir).name) + values = executor.macro_values() + if app.metadata and app.metadata.get(USE_WITH_RAY_CLUSTER_KEY, False): - assert len(app.roles) == 1, "Only one command is supported for Ray jobs." + srun_cmds: list[list[str]] = [] + + for role in app.roles: + if values: + role = values.apply(role) + srun_cmd = [role.entrypoint] + role.args + srun_cmds.append([" ".join(srun_cmd)]) + command = [app.roles[0].entrypoint] + app.roles[0].args req = SlurmRayRequest( name=app.roles[0].name, @@ -114,12 +123,12 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t executor=executor, workdir=f"/{RUNDIR_NAME}/code", nemo_run_dir=os.path.join(executor.tunnel.job_dir, Path(job_dir).name), + command_groups=srun_cmds, ) else: srun_cmds: list[list[str]] = [] jobs = [] envs = {} - values = executor.macro_values() if values: executor.env_vars = { diff --git a/test/core/execution/artifacts/group_resource_req_slurm.sh b/test/core/execution/artifacts/group_resource_req_slurm.sh index 3bbde92e..a68b77d9 100644 --- a/test/core/execution/artifacts/group_resource_req_slurm.sh +++ b/test/core/execution/artifacts/group_resource_req_slurm.sh @@ -30,6 +30,7 @@ nodes_array=($nodes) head_node=${nodes_array[0]} head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +export CUSTOM_ENV_1=some_value_1 export ENV_VAR=value From f21726420a3c9b80eecf24f36dffded41b58bfe2 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Sun, 8 Jun 2025 11:10:52 -0700 Subject: [PATCH 2/4] test Signed-off-by: Hemil Desai --- test/run/ray/test_slurm_ray_request.py | 176 +++++++++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/test/run/ray/test_slurm_ray_request.py b/test/run/ray/test_slurm_ray_request.py index 32a96734..6609872e 100644 --- a/test/run/ray/test_slurm_ray_request.py +++ b/test/run/ray/test_slurm_ray_request.py @@ -405,3 +405,179 @@ def test_array_assertion(self): with pytest.raises(AssertionError, match="array is not supported"): request.materialize() + + def test_command_groups_env_vars(self): + """Test environment variables are properly set for each command group.""" + # Create executor with environment variables + executor = SlurmExecutor( + account="test_account", + env_vars={"GLOBAL_ENV": "global_value"}, + ) + executor.run_as_group = True + + # Create resource groups with different env vars + resource_group = [ + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + container_image="image1", + 
env_vars={"GROUP1_ENV": "group1_value"}, + container_mounts=["/mount1"], + ), + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=1, + container_image="image2", + env_vars={"GROUP2_ENV": "group2_value"}, + container_mounts=["/mount2"], + ), + ] + executor.resource_group = resource_group + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-cluster", + cluster_dir="/tmp/test_jobs/test-ray-cluster", + template_name="ray.sub.j2", + executor=executor, + command_groups=[["cmd0"], ["cmd1"], ["cmd2"]], + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # Check global env vars are set in setup section + assert "export GLOBAL_ENV=global_value" in script + + # Check that command groups generate srun commands (excluding the first one) + # The template should have a section for srun_commands + assert "# Run extra commands" in script + assert "srun" in script + assert "cmd1" in script # First command group after skipping index 0 + assert "cmd2" in script # Second command group + + def test_command_groups_without_resource_group(self): + """Test command groups work without resource groups.""" + executor = SlurmExecutor( + account="test_account", + env_vars={"GLOBAL_ENV": "global_value"}, + ) + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-cluster", + cluster_dir="/tmp/test_jobs/test-ray-cluster", + template_name="ray.sub.j2", + executor=executor, + command_groups=[["cmd0"], ["cmd1"]], + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # Should have global env vars + assert "export GLOBAL_ENV=global_value" in script + + # Should have srun commands for overlapping groups (skipping first) + assert "srun" in script + assert "--overlap" in script + assert "cmd1" in script # Second command in the list (index 1) + + def test_env_vars_formatting(self): + """Test that environment variables are properly formatted as export statements.""" + executor = SlurmExecutor( + account="test_account", + env_vars={ + "VAR_WITH_SPACES": "value with spaces", + "PATH_VAR": "/usr/bin:/usr/local/bin", + "EMPTY_VAR": "", + "NUMBER_VAR": "123", + }, + ) + executor.tunnel = Mock(spec=SSHTunnel) + executor.tunnel.job_dir = "/tmp/test_jobs" + + request = SlurmRayRequest( + name="test-ray-cluster", + cluster_dir="/tmp/test_jobs/test-ray-cluster", + template_name="ray.sub.j2", + executor=executor, + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # Check all environment variables are properly exported + assert "export VAR_WITH_SPACES=value with spaces" in script + assert "export PATH_VAR=/usr/bin:/usr/local/bin" in script + assert "export EMPTY_VAR=" in script + assert "export NUMBER_VAR=123" in script + + def test_group_env_vars_integration(self): + """Test full integration of group environment variables matching the artifact pattern.""" + # This test verifies the behavior seen in group_resource_req_slurm.sh + executor = SlurmExecutor( + account="your_account", + partition="your_partition", + time="00:30:00", + nodes=1, + ntasks_per_node=8, + gpus_per_node=8, + container_image="some-image", + container_mounts=["/some/job/dir/sample_job:/nemo_run"], + env_vars={"ENV_VAR": "value"}, + ) + executor.run_as_group = True + + # Set up resource groups with specific env vars + resource_group = [ + # First group (index 0) - for the head/main command + 
SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=8, + container_image="some-image", + env_vars={"CUSTOM_ENV_1": "some_value_1"}, + container_mounts=["/some/job/dir/sample_job:/nemo_run"], + ), + # Second group (index 1) + SlurmExecutor.ResourceRequest( + packager=Mock(), + nodes=1, + ntasks_per_node=8, + container_image="different_container_image", + env_vars={"CUSTOM_ENV_1": "some_value_1"}, + container_mounts=["/some/job/dir/sample_job:/nemo_run"], + ), + ] + executor.resource_group = resource_group + + # Mock tunnel + tunnel_mock = Mock(spec=SSHTunnel) + tunnel_mock.job_dir = "/some/job/dir" + executor.tunnel = tunnel_mock + + request = SlurmRayRequest( + name="sample_job", + cluster_dir="/some/job/dir/sample_job", + template_name="ray.sub.j2", + executor=executor, + command_groups=[ + ["bash ./scripts/start_server.sh"], + ["bash ./scripts/echo.sh server_host=$het_group_host_0"], + ], + launch_cmd=["sbatch", "--parsable"], + ) + + script = request.materialize() + + # Verify the pattern matches the artifact: + # 1. Global env vars should be exported in setup + assert "export ENV_VAR=value" in script + + # The template should include group_env_vars for proper env var handling per command + # (The actual env var exports per command happen in the template rendering) From e24c10fdd9a6450696a340ed39daeb17275e2437 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 11 Jun 2025 08:14:00 -0700 Subject: [PATCH 3/4] Fix Signed-off-by: Hemil Desai --- nemo_run/run/torchx_backend/packaging.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo_run/run/torchx_backend/packaging.py b/nemo_run/run/torchx_backend/packaging.py index e80247de..ca88f634 100644 --- a/nemo_run/run/torchx_backend/packaging.py +++ b/nemo_run/run/torchx_backend/packaging.py @@ -230,9 +230,9 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): if metadata: if USE_WITH_RAY_CLUSTER_KEY in metadata: - assert isinstance(executor, SlurmExecutor), ( - f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for SlurmExecutor" - ) + assert isinstance( + executor, SlurmExecutor + ), f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for SlurmExecutor" app_def.metadata = metadata return app_def @@ -246,7 +246,7 @@ def merge_executables(app_defs: Iterator[specs.AppDef], name: str) -> specs.AppD if USE_WITH_RAY_CLUSTER_KEY in metadata: assert idx == 0, f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for the first command" - result.metadata.update(app_def.metadata) + result.metadata.update(metadata) result.roles.extend(app_def.roles) return result From 1a267ae7f0f5fb05def51d06ea25d18872a57923 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 11 Jun 2025 08:28:42 -0700 Subject: [PATCH 4/4] fmt Signed-off-by: Hemil Desai --- nemo_run/run/torchx_backend/packaging.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo_run/run/torchx_backend/packaging.py b/nemo_run/run/torchx_backend/packaging.py index ca88f634..99567fd0 100644 --- a/nemo_run/run/torchx_backend/packaging.py +++ b/nemo_run/run/torchx_backend/packaging.py @@ -230,9 +230,9 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool): if metadata: if USE_WITH_RAY_CLUSTER_KEY in metadata: - assert isinstance( - executor, SlurmExecutor - ), f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for SlurmExecutor" + assert isinstance(executor, SlurmExecutor), ( + f"{USE_WITH_RAY_CLUSTER_KEY} is only supported for SlurmExecutor" + ) app_def.metadata = metadata return app_def
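
---

For reference, a minimal usage sketch of the feature this series adds. It follows the ``start()`` docstring in PATCH 1/4 and the integration test in PATCH 2/4; the receiver object ``ray_job``, its construction, and ``my_slurm_executor`` are illustrative assumptions, not APIs confirmed by these patches:

    # Sketch only: assumes `ray_job` is a live SlurmRayJob helper wrapping a
    # configured SlurmExecutor (`my_slurm_executor`), per the start() docstring.
    ray_job.start(
        executor=my_slurm_executor,
        command="python train.py",   # driver command, runs once the Ray head is up
        workdir="./src",             # synced and used as the CWD inside the container
        command_groups=[
            # Index 0 is the main command; SlurmRayRequest.materialize() skips it
            # when generating the overlapped sruns (`if idx == 0: continue`).
            ["bash ./scripts/start_server.sh"],
            # Each later group is launched in the background after the cluster
            # is ready, via `srun --overlap ... &`.
            ["bash ./scripts/echo.sh server_host=$het_group_host_0"],
        ],
    )

Per the materialize() changes, each extra group writes to its own log files (``ray-overlap-<idx>.out``/``.err``) under ``$CLUSTER_DIR/logs`` and, when ``run_as_group`` is set with a ``resource_group`` matching ``command_groups`` in length, uses that group's container image, mounts, and exported env vars.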