diff --git a/src/nemo_run/core/execution/slurm.py b/src/nemo_run/core/execution/slurm.py index bde336da..6ee6e67f 100644 --- a/src/nemo_run/core/execution/slurm.py +++ b/src/nemo_run/core/execution/slurm.py @@ -816,7 +816,7 @@ def materialize(self) -> str: ) het_parameters.update( { - "job_name": f"{self.slurm_config.account}-{self.slurm_config.account.split('_')[-1]}.{self.jobs[i]}", + "job_name": f"{job_details.job_name[:-2] if job_details.job_name.endswith('-0') else job_details.job_name}-{i}", "nodes": resource_req.nodes, "ntasks_per_node": resource_req.ntasks_per_node, "gpus_per_node": resource_req.gpus_per_node, @@ -995,7 +995,7 @@ def get_container_flags( return sbatch_script def __repr__(self) -> str: - return f"""{' '.join(self.cmd + ['$SBATCH_SCRIPT'])} + return f"""{" ".join(self.cmd + ["$SBATCH_SCRIPT"])} #---------------- # SBATCH_SCRIPT diff --git a/test/core/execution/test_slurm.py b/test/core/execution/test_slurm.py index 8ab4cad1..b858bc7a 100644 --- a/test/core/execution/test_slurm.py +++ b/test/core/execution/test_slurm.py @@ -473,16 +473,16 @@ class CustomJobDetails(SlurmJobDetails): @property def stdout(self) -> Path: assert self.folder - return Path(self.folder / "sbatch_job.out") + return Path(self.folder) / "sbatch_job.out" @property def srun_stdout(self) -> Path: assert self.folder - return Path(self.folder / "log_job.out") + return Path(self.folder) / "log_job.out" dummy_slurm_request, _ = dummy_slurm_request_with_artifact dummy_slurm_request.slurm_config.job_details = CustomJobDetails( - job_name="custom_sample_job", folder=Path("/custom_folder") + job_name="custom_sample_job", folder="/custom_folder" ) sbatch_script = dummy_slurm_request.materialize() assert "#SBATCH --job-name=custom_sample_job" in sbatch_script @@ -633,25 +633,23 @@ class CustomJobDetails(SlurmJobDetails): @property def stdout(self) -> Path: assert self.folder - return Path(self.folder / "sbatch_job.out") + return Path(self.folder) / "sbatch_job.out" @property def srun_stdout(self) -> Path: assert self.folder - return Path(self.folder / f"log_{self.job_name}.out") + return Path(self.folder) / f"log_{self.job_name}.out" group_resource_req_slurm_request, _ = group_resource_req_slurm_request_with_artifact group_resource_req_slurm_request.slurm_config.job_details = CustomJobDetails( - job_name="custom_sample_job", folder=Path("/custom_folder") + job_name="custom_sample_job", folder="/custom_folder" ) group_resource_req_slurm_request.slurm_config.resource_group[0].job_details = copy.deepcopy( group_resource_req_slurm_request.slurm_config.job_details ) group_resource_req_slurm_request.slurm_config.resource_group[ 1 - ].job_details = CustomJobDetails( - job_name="custom_sample_job_2", folder=Path("/custom_folder_2") - ) + ].job_details = CustomJobDetails(job_name="custom_sample_job_2", folder="/custom_folder_2") sbatch_script = group_resource_req_slurm_request.materialize() assert "#SBATCH --job-name=custom_sample_job" in sbatch_script @@ -680,3 +678,41 @@ def test_ft_het_slurm_request_materialize( sbatch_script = re.sub(r"--rdzv-id \d+", "--rdzv-id 1", sbatch_script) expected = re.sub(r"--rdzv-id \d+", "--rdzv-id 1", expected) assert sbatch_script.strip() == expected.strip() + + def test_het_job_name_prefix(self, het_slurm_request_with_artifact): + # Set the job_name_prefix to a custom value + het_request, _ = het_slurm_request_with_artifact + het_request.slurm_config.job_name_prefix = "prefix_" + + # Materialize the batch request script + sbatch_script = het_request.materialize() + + # For each job in the heterogeneous request, verify the job name uses the prefix + for job in het_request.jobs: + expected = f"prefix_{job}" + assert expected in sbatch_script, f"Expected job name '{expected}' not found in script" + + def test_het_job_custom_details_job_name(self, het_slurm_request_with_artifact): + # Test that the job name from CustomJobDetails is used for heterogeneous slurm requests + from nemo_run.core.execution.slurm import SlurmJobDetails + + het_request, _ = het_slurm_request_with_artifact + + class CustomJobDetails(SlurmJobDetails): + @property + def stdout(self): + assert self.folder + return Path(self.folder) / "sbatch_job.out" + + @property + def srun_stdout(self): + assert self.folder + return Path(self.folder) / "log_job.out" + + custom_name = "custom_het_job" + het_request.slurm_config.job_details = CustomJobDetails( + job_name=custom_name, folder="/custom_folder" + ) + sbatch_script = het_request.materialize() + for i in range(len(het_request.jobs)): + assert f"#SBATCH --job-name={custom_name}-{i}" in sbatch_script