src/nemo_run/core/execution/slurm.py: 51 changes (48 additions & 3 deletions)
@@ -293,6 +293,7 @@ class ResourceRequest:
         env_vars: dict[str, str] = field(default_factory=dict)
         srun_args: Optional[list[str]] = None
         job_details: SlurmJobDetails = field(default_factory=SlurmJobDetails)
+        het_group_index: Optional[int] = None
 
     account: str
     partition: Optional[str] = None
@@ -334,6 +335,7 @@ class ResourceRequest:
     monitor_group_job: bool = True
     monitor_group_job_wait_time: int = 60
     setup_lines: Optional[str] = None
+    het_group_indices: Optional[list[int]] = None
 
     #: Set by the executor; cannot be initialized
     job_name: str = field(init=False, default="nemo-job")
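
Taken together, the two new fields form the public surface of this change: het_group_indices on the executor assigns each merged task to a Slurm heterogeneous job group, and merge() (below) copies the per-task value into each ResourceRequest as het_group_index. A minimal usage sketch, assuming SlurmExecutor's dataclass keyword constructor implied by the field syntax above; the account and partition values are placeholders:

from nemo_run.core.execution.slurm import SlurmExecutor

# Three tasks, two het groups: tasks 0 and 1 share het group 0 (one
# sbatch block / one allocation), task 2 gets het group 1.
executor = SlurmExecutor(
    account="my_account",         # placeholder
    partition="batch",            # placeholder
    heterogeneous=True,           # must be True when het_group_indices is set
    het_group_indices=[0, 0, 1],  # one entry per task, non-decreasing
)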
@@ -355,6 +357,21 @@ def merge(
 
         main_executor = executors[0]
         main_executor.run_as_group = True
+
+        if main_executor.het_group_indices:
+            assert (
+                main_executor.heterogeneous
+            ), "heterogeneous must be True if het_group_indices is provided"
+            assert (
+                len(main_executor.het_group_indices) == num_tasks
+            ), "het_group_indices must be the same length as the number of tasks"
+            assert all(
+                x <= y
+                for x, y in zip(
+                    main_executor.het_group_indices, main_executor.het_group_indices[1:]
+                )
+            ), "het_group_indices must be non-decreasing"
+
         main_executor.resource_group = [
             cls.ResourceRequest(
                 packager=copy.deepcopy(main_executor.packager),
@@ -367,10 +384,13 @@ def merge(
                 gpus_per_task=main_executor.gpus_per_task,
                 srun_args=main_executor.srun_args,
                 job_details=copy.deepcopy(main_executor.job_details),
+                het_group_index=main_executor.het_group_indices[0]
+                if main_executor.het_group_indices
+                else None,
             )
         ]
 
-        for executor in executors[1:]:
+        for i, executor in enumerate(executors[1:]):
             main_executor.resource_group.append(
                 cls.ResourceRequest(
                     packager=copy.deepcopy(executor.packager),
@@ -383,6 +403,9 @@ def merge(
                     gpus_per_task=executor.gpus_per_task,
                     srun_args=executor.srun_args,
                     job_details=copy.deepcopy(executor.job_details),
+                    het_group_index=main_executor.het_group_indices[i + 1]
+                    if main_executor.het_group_indices
+                    else None,
                 )
             )
 
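Reading the merge() changes end to end: entry 0 of het_group_indices goes to the first executor's ResourceRequest and entry i + 1 to each subsequent one, so the list must line up one-to-one with the merged tasks. A sketch of the resulting assignment, assuming the merge(executors, num_tasks) call shape suggested by the names in these hunks; make_executor() is a hypothetical stand-in for real executor construction:

# Hypothetical illustration of the assignment and validation in merge().
executors = [make_executor() for _ in range(3)]  # make_executor is a stand-in
executors[0].heterogeneous = True
executors[0].het_group_indices = [0, 0, 1]   # valid: non-decreasing

SlurmExecutor.merge(executors, num_tasks=3)  # mutates executors[0], per the hunk above
assert [r.het_group_index for r in executors[0].resource_group] == [0, 0, 1]

executors[0].het_group_indices = [1, 0, 2]   # invalid: decreasing step
# merge() would now fail the all(x <= y ...) assertion.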
@@ -803,8 +826,25 @@ def materialize(self) -> str:
         sbatch_flags = []
         if self.slurm_config.heterogeneous:
             assert len(self.jobs) == len(self.slurm_config.resource_group)
+            final_group_index = len(self.slurm_config.resource_group) - 1
+            if self.slurm_config.het_group_indices:
+                final_group_index = self.slurm_config.het_group_indices.index(
+                    max(self.slurm_config.het_group_indices)
+                )
+
             for i in range(len(self.slurm_config.resource_group)):
                 resource_req = self.slurm_config.resource_group[i]
+                if resource_req.het_group_index:
+                    assert (
+                        self.slurm_config.resource_group[i - 1].het_group_index is not None
+                    ), "het_group_index must be set for all requests in resource_group"
+                if (
+                    i > 0
+                    and resource_req.het_group_index
+                    == self.slurm_config.resource_group[i - 1].het_group_index
+                ):
+                    continue
+
                 het_parameters = parameters.copy()
                 het_parameters["output"] = parameters["output"].replace(
                     original_job_name, self.jobs[i]
@@ -824,7 +864,7 @@ def materialize(self) -> str:
                 )
                 for k in sorted(parameters):
                     sbatch_flags.append(_as_sbatch_flag(k, het_parameters[k]))
-                if i != len(self.slurm_config.resource_group) - 1:
+                if i != final_group_index:
                     sbatch_flags.append("#SBATCH hetjob")
         else:
             for k in sorted(parameters):
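
The skip logic above is the heart of the feature: consecutive ResourceRequests that share a het_group_index collapse into a single #SBATCH block, with "#SBATCH hetjob" written after every block except the one holding the highest group index. A standalone sketch of that grouping, using a plain list of ints in place of the resource group (it covers only the case where het_group_indices is provided):

def emitted_blocks(het_group_indices: list[int]) -> list[str]:
    # Mirrors the materialize() loop: one block per distinct consecutive
    # group, "#SBATCH hetjob" between blocks.
    final_group_index = het_group_indices.index(max(het_group_indices))
    out: list[str] = []
    for i, group in enumerate(het_group_indices):
        if i > 0 and group == het_group_indices[i - 1]:
            continue  # same group as the previous task: reuse its block
        out.append(f"<#SBATCH flags for het group {group}>")
        if i != final_group_index:
            out.append("#SBATCH hetjob")
    return out

print(emitted_blocks([0, 0, 1]))
# ['<#SBATCH flags for het group 0>', '#SBATCH hetjob',
#  '<#SBATCH flags for het group 1>']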
@@ -934,7 +974,12 @@ def get_container_flags(
         _srun_args.extend(self.slurm_config.srun_args or [])
 
         if self.slurm_config.run_as_group and self.slurm_config.heterogeneous:
-            het_group_flag = [f"--het-group={group_ind}"]
+            het_group_index = (
+                self.slurm_config.resource_group[group_ind].het_group_index
+                if self.slurm_config.resource_group[group_ind].het_group_index is not None
+                else group_ind
+            )
+            het_group_flag = [f"--het-group={het_group_index}"]
         else:
             het_group_flag = []
 
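Finally, the srun side: each task's --het-group now comes from its ResourceRequest when an explicit mapping exists, so tasks that share a group launch into the same heterogeneous component, while the old positional behavior is kept as a fallback. A small sketch of that resolution, with group_ind playing the same positional role as in the hunk above and a plain list standing in for the resource group:

from typing import Optional

def resolve_het_group(het_group_indices: list[Optional[int]], group_ind: int) -> str:
    # het_group_indices[i] stands in for resource_group[i].het_group_index;
    # None means no explicit mapping was provided.
    explicit = het_group_indices[group_ind]
    return f"--het-group={explicit if explicit is not None else group_ind}"

print(resolve_het_group([0, 0, 1], 1))     # --het-group=0 (task 1 joins group 0)
print(resolve_het_group([None, None], 1))  # --het-group=1 (positional fallback)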