diff --git a/src/nemo_run/core/execution/slurm.py b/src/nemo_run/core/execution/slurm.py
index 623ce361..a7aebfd2 100644
--- a/src/nemo_run/core/execution/slurm.py
+++ b/src/nemo_run/core/execution/slurm.py
@@ -293,6 +293,7 @@ class ResourceRequest:
         env_vars: dict[str, str] = field(default_factory=dict)
         srun_args: Optional[list[str]] = None
         job_details: SlurmJobDetails = field(default_factory=SlurmJobDetails)
+        het_group_index: Optional[int] = None
 
     account: str
     partition: Optional[str] = None
@@ -334,6 +335,7 @@ class ResourceRequest:
     monitor_group_job: bool = True
     monitor_group_job_wait_time: int = 60
     setup_lines: Optional[str] = None
+    het_group_indices: Optional[list[int]] = None
 
     #: Set by the executor; cannot be initialized
     job_name: str = field(init=False, default="nemo-job")
@@ -355,6 +357,21 @@ def merge(
 
         main_executor = executors[0]
         main_executor.run_as_group = True
+
+        if main_executor.het_group_indices:
+            assert (
+                main_executor.heterogeneous
+            ), "heterogeneous must be True if het_group_indices is provided"
+            assert (
+                len(main_executor.het_group_indices) == num_tasks
+            ), "het_group_indices must be the same length as the number of tasks"
+            assert all(
+                x <= y
+                for x, y in zip(
+                    main_executor.het_group_indices, main_executor.het_group_indices[1:]
+                )
+            ), "het_group_indices must be non-decreasing"
+
         main_executor.resource_group = [
             cls.ResourceRequest(
                 packager=copy.deepcopy(main_executor.packager),
@@ -367,10 +384,13 @@ def merge(
                 gpus_per_task=main_executor.gpus_per_task,
                 srun_args=main_executor.srun_args,
                 job_details=copy.deepcopy(main_executor.job_details),
+                het_group_index=main_executor.het_group_indices[0]
+                if main_executor.het_group_indices
+                else None,
             )
         ]
 
-        for executor in executors[1:]:
+        for i, executor in enumerate(executors[1:]):
             main_executor.resource_group.append(
                 cls.ResourceRequest(
                     packager=copy.deepcopy(executor.packager),
@@ -383,6 +403,9 @@ def merge(
                     gpus_per_task=executor.gpus_per_task,
                     srun_args=executor.srun_args,
                     job_details=copy.deepcopy(executor.job_details),
+                    het_group_index=main_executor.het_group_indices[i + 1]
+                    if main_executor.het_group_indices
+                    else None,
                 )
             )
 
@@ -803,8 +826,25 @@ def materialize(self) -> str:
         sbatch_flags = []
         if self.slurm_config.heterogeneous:
             assert len(self.jobs) == len(self.slurm_config.resource_group)
+            final_group_index = len(self.slurm_config.resource_group) - 1
+            if self.slurm_config.het_group_indices:
+                final_group_index = self.slurm_config.het_group_indices.index(
+                    max(self.slurm_config.het_group_indices)
+                )
+
             for i in range(len(self.slurm_config.resource_group)):
                 resource_req = self.slurm_config.resource_group[i]
+                if resource_req.het_group_index is not None:
+                    assert (
+                        self.slurm_config.resource_group[i - 1].het_group_index is not None
+                    ), "het_group_index must be set for all requests in resource_group"
+                    if (
+                        i > 0
+                        and resource_req.het_group_index
+                        == self.slurm_config.resource_group[i - 1].het_group_index
+                    ):
+                        continue
+
                 het_parameters = parameters.copy()
                 het_parameters["output"] = parameters["output"].replace(
                     original_job_name, self.jobs[i]
                 )
@@ -824,7 +864,7 @@ def materialize(self) -> str:
                 for k in sorted(parameters):
                     sbatch_flags.append(_as_sbatch_flag(k, het_parameters[k]))
-                if i != len(self.slurm_config.resource_group) - 1:
+                if i != final_group_index:
                     sbatch_flags.append("#SBATCH hetjob")
         else:
             for k in sorted(parameters):
@@ -934,7 +974,12 @@ def get_container_flags(
         _srun_args.extend(self.slurm_config.srun_args or [])
 
         if self.slurm_config.run_as_group and self.slurm_config.heterogeneous:
-            het_group_flag = [f"--het-group={group_ind}"]
+            het_group_index = (
+                self.slurm_config.resource_group[group_ind].het_group_index
+                if self.slurm_config.resource_group[group_ind].het_group_index is not None
+                else group_ind
+            )
+            het_group_flag = [f"--het-group={het_group_index}"]
         else:
             het_group_flag = []
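
A rough usage sketch of the new knob, not part of the diff: it assumes `account`, `nodes`, and `ntasks_per_node` keep their current `SlurmExecutor` field names, and the values are illustrative only.

    from nemo_run.core.execution.slurm import SlurmExecutor

    # Three tasks merged into a single heterogeneous sbatch script.
    # het_group_indices=[0, 0, 1] places the first two tasks in het group 0
    # and the third in het group 1, so materialize() should emit one
    # "#SBATCH hetjob" separator and the srun lines should carry
    # --het-group=0, --het-group=0, and --het-group=1 respectively.
    executors = [
        SlurmExecutor(account="demo", nodes=1, ntasks_per_node=8),
        SlurmExecutor(account="demo", nodes=1, ntasks_per_node=8),
        SlurmExecutor(account="demo", nodes=2, ntasks_per_node=1),
    ]
    executors[0].heterogeneous = True
    executors[0].het_group_indices = [0, 0, 1]

    merged = SlurmExecutor.merge(executors, num_tasks=3)

Without `het_group_indices`, behavior is unchanged: every `ResourceRequest` gets its own het group, and `get_container_flags` falls back to the positional `group_ind`.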