Skip to content

Commit 661eaed

Browse files
authored
fix: Pass DGXC to ft_launcher (#402)
* pass dgxc to ft_launcher Signed-off-by: oliver könig <okoenig@nvidia.com> * feat: Add FT to DGXC Signed-off-by: oliver könig <okoenig@nvidia.com> * torchrun_job Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * format Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * revert Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * cleanup Signed-off-by: oliver könig <okoenig@nvidia.com> * change template Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * test Signed-off-by: oliver könig <okoenig@nvidia.com> * retries Signed-off-by: oliver könig <okoenig@nvidia.com> * TORCHX_MAX_RETRIES Signed-off-by: oliver könig <okoenig@nvidia.com> * cleanup Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * bump FT interface Signed-off-by: oliver könig <okoenig@nvidia.com> * --ft-use-infra-group-rank=False Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * add warning for max_restarts Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * add test Signed-off-by: oliver könig <okoenig@nvidia.com> * fix Signed-off-by: oliver könig <okoenig@nvidia.com> * tests Signed-off-by: oliver könig <okoenig@nvidia.com> * fix test Signed-off-by: oliver könig <okoenig@nvidia.com> --------- Signed-off-by: oliver könig <okoenig@nvidia.com>
1 parent 85a2b9c commit 661eaed

12 files changed

Lines changed: 673 additions & 22 deletions

File tree

nemo_run/core/execution/dgxcloud.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@
2929
import requests
3030
from invoke.context import Context
3131

32-
from nemo_run.config import get_nemorun_home
32+
from nemo_run.config import RUNDIR_NAME, get_nemorun_home
3333
from nemo_run.core.execution.base import Executor, ExecutorMacros
34+
from nemo_run.core.execution.launcher import FaultTolerance, Launcher, Torchrun
35+
from nemo_run.core.execution.utils import fill_template
36+
from nemo_run.core.frontend.console.api import CONSOLE
3437
from nemo_run.core.packaging.base import Packager
3538
from nemo_run.core.packaging.git import GitArchivePackager
3639

@@ -461,6 +464,24 @@ def cancel(self, job_id: str):
461464
response.text,
462465
)
463466

467+
def _setup_launcher(self) -> None:
    """Configure this executor for the resolved launcher.

    For ``Torchrun``/``FaultTolerance`` launchers the executor schedules a
    single task per node and hands the original per-node process count to
    torchrun, which spawns the workers itself. For ``FaultTolerance`` it
    additionally derives the file paths that ft_launcher needs.
    """
    super()._setup_launcher()
    launcher = self.launcher
    if launcher and isinstance(launcher, (FaultTolerance, Torchrun)):
        # torchrun manages the per-GPU worker processes, so run exactly one
        # task per node and pass the original process count through.
        self.torchrun_nproc_per_node = self.nprocs_per_node
        self.ntasks_per_node = 1
        CONSOLE.log(
            f"Detected {launcher.__class__.__name__} launcher, setting ntasks_per_node=1 and torchrun_nproc_per_node={self.torchrun_nproc_per_node}"
        )

    if launcher and isinstance(launcher, FaultTolerance):
        # NOTE(review): base_dir nests the job-dir basename under job_dir
        # (job_dir/<basename of job_dir>) — presumably mirroring how the job
        # directory is mounted in the container; confirm against callers.
        base_dir = os.path.join(self.job_dir, Path(self.job_dir).name)
        launcher.cfg_path = os.path.join(base_dir, f"{self.job_name}_ft_cfg.yml")
        # The finished-flag file lives under the in-container run dir
        # ("/" + RUNDIR_NAME), not under base_dir like the other two paths.
        launcher.finished_flag_file = os.path.join(
            "/", RUNDIR_NAME, f"{self.job_name}_finished_flag"
        )
        launcher.job_results_file = os.path.join(base_dir, f"{self.job_name}_job_results")
484+
464485
def cleanup(self, handle: str): ...
465486

466487
def assign(
@@ -556,3 +577,55 @@ def _default_headers(self, token: Optional[str] = None) -> dict:
556577
if token:
557578
headers["Authorization"] = f"Bearer {token}"
558579
return headers
580+
581+
582+
@dataclass(kw_only=True)
class DGXCloudRequest:
    """Request object that renders the DGXC entrypoint script for a job."""

    launch_cmd: list[str]  # command line executed by the rendered script
    jobs: list[str]  # names of the jobs covered by this request
    executor: DGXCloudExecutor
    max_retries: int  # surfaced to the template (exported as TORCHX_MAX_RETRIES)
    extra_env: dict[str, str]  # merged over the executor's env_vars
    launcher: Optional[Launcher] = None

    def materialize(self) -> str:
        """Creates the content of a DGXC entrypoint script."""
        # Executor defaults first, request-specific entries override them.
        merged_env = self.executor.env_vars | self.extra_env
        exports = [f"export {name.upper()}={val}" for name, val in merged_env.items()]

        ft_launcher = self.launcher if isinstance(self.launcher, FaultTolerance) else None

        template_vars = {
            "max_retries": self.max_retries,
            "env_vars": exports,
            "training_command": " ".join(self.launch_cmd),
            "ft_enabled": ft_launcher is not None,
        }

        if ft_launcher is not None:
            # All three paths must have been filled in (by the executor's
            # launcher setup) before the script can be rendered.
            assert (
                ft_launcher.cfg_path
                and ft_launcher.finished_flag_file
                and ft_launcher.job_results_file
            ), "Fault Tolerance requires cfg_path, finished_flag_file, and job_results_file"

            template_vars["fault_tol_cfg_path"] = ft_launcher.cfg_path
            template_vars["fault_tol_finished_flag_file"] = ft_launcher.finished_flag_file
            template_vars["fault_tol_job_results_file"] = ft_launcher.job_results_file

        return fill_template("dgxc.sh.j2", template_vars)

    def __repr__(self) -> str:
        # Header lines plus the fully rendered script, for easy inspection.
        return f"""# DGXC Entrypoint Script Request
# Executor: {self.executor.__class__.__name__}
# Jobs: {self.jobs}
# ---------------------------------------------------
{self.materialize()}
"""
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
{%- import "ft_launcher_dgxc.j2" as fault_tolerance -%}
#!/bin/bash
# DGXC entrypoint: exports env vars, optionally wires in fault-tolerance
# setup/teardown, then runs the training command and propagates its exit code.

set -evx # -e: exit on unhandled errors; -v/-x: echo and trace commands (auto-exit is suspended around the training command below)
export PYTHONUNBUFFERED=1
export TORCHX_MAX_RETRIES={{max_retries}}

{%- for env_var in env_vars %}
{{env_var}}
{%- endfor %}

{%- if ft_enabled %}
{{ fault_tolerance.ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file) }}
{%- endif %}

echo "Starting training command..."
set +e # Turn off auto-exit so we can capture the code

{{ training_command }}

exitcode=$?
set -e

echo "Main command exited with code $exitcode"

{%- if ft_enabled %}
{{ fault_tolerance.ft_launcher_teardown() }}
{%- else %}

exit $exitcode
{%- endif %}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{% macro ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file) -%}
# This script uses experimental fault tolerance launcher
# Fault tolerance related items
export FAULT_TOL_CFG_PATH="{{fault_tol_cfg_path}}"
export FAULT_TOL_FINISHED_FLAG_FILE="{{fault_tol_finished_flag_file}}"

JOB_RESULTS_FILE="{{fault_tol_job_results_file}}"

# The finished flag is looked for next to the job-results file. Quote the
# command substitutions so paths containing spaces don't word-split
# (ShellCheck SC2046/SC2086).
is_training_finished() {
    test -f "$(dirname "$JOB_RESULTS_FILE")/$(basename "$FAULT_TOL_FINISHED_FLAG_FILE")"
}

# If a previous restart already finished training, stop immediately;
# otherwise clear stale flag/results files before relaunching.
if is_training_finished ; then
    echo "Training is finished";
    exit 0;
else
    rm -f "$FAULT_TOL_FINISHED_FLAG_FILE" "$JOB_RESULTS_FILE"
fi

{%- endmacro %}

{% macro ft_launcher_teardown() -%}
# $exitcode is set by the calling script (dgxc.sh.j2) after the training command.
exit $exitcode
{%- endmacro %}
File renamed without changes.

nemo_run/core/execution/templates/slurm.sh.j2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{%- import "ft_launcher.j2" as fault_tolerance -%}
1+
{%- import "ft_launcher_slurm.j2" as fault_tolerance -%}
22
#!/bin/bash
33
#
44
# Generated by NeMo Run

nemo_run/run/torchx_backend/components/ft_launcher.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import logging
1617
import shlex
1718
from typing import Optional
1819

@@ -22,6 +23,8 @@
2223

2324
from nemo_run.run.torchx_backend.components import torchrun
2425

26+
logger = logging.getLogger(__name__)
27+
2528

2629
# Adapted from torchrun component
2730
def ft_launcher(
@@ -92,30 +95,36 @@ def ft_launcher(
9295
):
9396
if workload_check_interval:
9497
ft_args += [
95-
"--ft-param-workload_check_interval",
98+
"--ft-workload_check_interval",
9699
str(workload_check_interval),
97100
]
98101

99102
if initial_rank_heartbeat_timeout:
100103
ft_args += [
101-
"--ft-param-initial_rank_heartbeat_timeout",
104+
"--ft-initial_rank_heartbeat_timeout",
102105
str(initial_rank_heartbeat_timeout),
103106
]
104107

105108
if rank_heartbeat_timeout:
106109
ft_args += [
107-
"--ft-param-rank_heartbeat_timeout",
110+
"--ft-rank_heartbeat_timeout",
108111
str(rank_heartbeat_timeout),
109112
]
110113

111114
if rank_termination_signal:
112-
ft_args += ["--ft-param-rank_termination_signal", rank_termination_signal]
115+
ft_args += ["--ft-rank_termination_signal", rank_termination_signal]
113116

114117
if log_level:
115-
ft_args += ["--ft-param-log_level", log_level]
118+
ft_args += ["--ft-log_level", log_level]
116119

117120
if max_restarts:
118-
ft_args += ["--max-restarts", str(max_restarts)]
121+
if dgxc is True:
122+
logger.warning("max_restarts is ignored for DGXCloudExecutor")
123+
else:
124+
ft_args += ["--max-restarts", str(max_restarts)]
125+
126+
if dgxc is True:
127+
ft_args += ["--ft-use-infra-group-rank", "False"]
119128

120129
else:
121130
ft_args = ["--ignore-missing-fault-tol-cfg"]

nemo_run/run/torchx_backend/packaging.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
203203
log_level=launcher.log_level,
204204
max_retries=executor.retries,
205205
max_restarts=launcher.max_restarts,
206+
dgxc=isinstance(executor, DGXCloudExecutor),
206207
use_env=use_env,
207208
)
208209
else:

nemo_run/run/torchx_backend/schedulers/dgxcloud.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
from nemo_run.config import get_nemorun_home
3939
from nemo_run.core.execution.base import Executor
40-
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudState
40+
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudRequest, DGXCloudState
4141
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
4242
from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin
4343

@@ -109,6 +109,23 @@ def _submit_dryrun( # type: ignore
109109
role = values.apply(role)
110110

111111
cmd = [role.entrypoint] + role.args
112+
113+
req = DGXCloudRequest(
114+
launch_cmd=cmd,
115+
jobs=[role.name],
116+
executor=executor,
117+
max_retries=role.max_retries,
118+
extra_env=role.env,
119+
launcher=executor.get_launcher(),
120+
)
121+
122+
# Write and copy sbatch script
123+
path = os.path.join(executor.experiment_dir, "torchrun_job.sh")
124+
script = req.materialize()
125+
126+
with open(path, "w") as f:
127+
f.write(script)
128+
112129
return AppDryRunInfo(
113130
DGXRequest(app=app, executor=executor, cmd=cmd, name=role.name),
114131
# Minimal function to show the config, if any
@@ -128,7 +145,9 @@ def schedule(self, dryrun_info: AppDryRunInfo[DGXRequest]) -> str:
128145

129146
# The DGXExecutor's launch call typically returns (job_id, handle).
130147
# We'll call it without additional parameters here.
131-
job_id, status = executor.launch(name=req.name, cmd=req.cmd)
148+
cmd = os.path.join(executor.experiment_dir, "torchrun_job.sh")
149+
req.launch_cmd = ["bash", cmd]
150+
job_id, status = executor.launch(name=req.name, cmd=req.launch_cmd)
132151
if not job_id:
133152
raise RuntimeError("Failed scheduling run on DGX: no job_id returned")
134153

test/core/execution/artifacts/ft_het_slurm.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ echo "$SLURM_JOB_ID ${SLURM_RESTART_COUNT:-0} X" >> "$JOB_RESULTS_FILE"
7777
export CUSTOM_ENV_1=some_value_1
7878

7979

80-
srun --het-group=0 --output /root/experiment/sample_job/log-account-account.sample_job-0_%j_${SLURM_RESTART_COUNT:-0}.out --container-image image_1 --container-mounts /root/experiment/sample_job:/nemo_run --container-workdir /nemo_run/code --wait=60 --kill-on-bad-exit=1 ft_launcher --ft-param-workload_check_interval 10 --ft-param-rank_heartbeat_timeout 10 --rdzv-backend c10d --rdzv-endpoint localhost:0 --rdzv-id 1 --nnodes 1 --nproc-per-node 1 --node-rank 0 --tee 3 --no-python test_ft.sh & pids[0]=$!
80+
srun --het-group=0 --output /root/experiment/sample_job/log-account-account.sample_job-0_%j_${SLURM_RESTART_COUNT:-0}.out --container-image image_1 --container-mounts /root/experiment/sample_job:/nemo_run --container-workdir /nemo_run/code --wait=60 --kill-on-bad-exit=1 ft_launcher --ft-workload_check_interval 10 --ft-rank_heartbeat_timeout 10 --rdzv-backend c10d --rdzv-endpoint localhost:0 --rdzv-id 1 --nnodes 1 --nproc-per-node 1 --node-rank 0 --tee 3 --no-python test_ft.sh & pids[0]=$!
8181

8282
sleep 30
8383

test/core/execution/artifacts/ft_slurm.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ echo "$SLURM_JOB_ID ${SLURM_RESTART_COUNT:-0} X" >> "$JOB_RESULTS_FILE"
6262

6363
# Command 1
6464

65-
srun --output /root/sample_job/log-account-account.sample_job_%j_${SLURM_RESTART_COUNT:-0}.out --container-mounts /root/sample_job:/nemo_run --container-workdir /nemo_run/code --wait=60 --kill-on-bad-exit=1 ft_launcher --ft-param-workload_check_interval 10 --ft-param-rank_heartbeat_timeout 10 --rdzv-backend c10d --rdzv-endpoint localhost:0 --rdzv-id 7680 --nnodes 1 --nproc-per-node 1 --node-rank 0 --tee 3 --no-python test_ft.sh
65+
srun --output /root/sample_job/log-account-account.sample_job_%j_${SLURM_RESTART_COUNT:-0}.out --container-mounts /root/sample_job:/nemo_run --container-workdir /nemo_run/code --wait=60 --kill-on-bad-exit=1 ft_launcher --ft-workload_check_interval 10 --ft-rank_heartbeat_timeout 10 --rdzv-backend c10d --rdzv-endpoint localhost:0 --rdzv-id 7680 --nnodes 1 --nproc-per-node 1 --node-rank 0 --tee 3 --no-python test_ft.sh
6666

6767
exitcode=$?
6868

0 commit comments

Comments
 (0)