|
29 | 29 | import requests |
30 | 30 | from invoke.context import Context |
31 | 31 |
|
32 | | -from nemo_run.config import get_nemorun_home |
| 32 | +from nemo_run.config import RUNDIR_NAME, get_nemorun_home |
33 | 33 | from nemo_run.core.execution.base import Executor, ExecutorMacros |
| 34 | +from nemo_run.core.execution.launcher import FaultTolerance, Launcher, Torchrun |
| 35 | +from nemo_run.core.execution.utils import fill_template |
| 36 | +from nemo_run.core.frontend.console.api import CONSOLE |
34 | 37 | from nemo_run.core.packaging.base import Packager |
35 | 38 | from nemo_run.core.packaging.git import GitArchivePackager |
36 | 39 |
|
@@ -461,6 +464,24 @@ def cancel(self, job_id: str): |
461 | 464 | response.text, |
462 | 465 | ) |
463 | 466 |
|
def _setup_launcher(self):
    """Adjust executor settings to match the configured launcher.

    The parent implementation runs first. If the launcher is a ``Torchrun``
    or ``FaultTolerance`` instance, ``torchrun_nproc_per_node`` takes the
    value of ``nprocs_per_node``, ``ntasks_per_node`` is forced to 1, and the
    change is logged to the console. For ``FaultTolerance`` launchers the
    config, finished-flag and job-results paths are also filled in.
    """
    super()._setup_launcher()

    launcher = self.launcher
    # Nothing to adjust without a launcher, or for any other launcher type.
    if not launcher:
        return
    if not isinstance(launcher, (FaultTolerance, Torchrun)):
        return

    self.torchrun_nproc_per_node = self.nprocs_per_node
    self.ntasks_per_node = 1
    CONSOLE.log(
        f"Detected {launcher.__class__.__name__} launcher, setting ntasks_per_node=1 and torchrun_nproc_per_node={self.torchrun_nproc_per_node}"
    )

    if not isinstance(launcher, FaultTolerance):
        return

    # Config and results files go in a sub-directory of job_dir named after
    # job_dir's own final path component.
    base_dir = os.path.join(self.job_dir, Path(self.job_dir).name)
    job_name = self.job_name
    launcher.cfg_path = os.path.join(base_dir, f"{job_name}_ft_cfg.yml")
    # The finished flag is rooted at "/<RUNDIR_NAME>" rather than base_dir.
    launcher.finished_flag_file = os.path.join("/", RUNDIR_NAME, f"{job_name}_finished_flag")
    launcher.job_results_file = os.path.join(base_dir, f"{job_name}_job_results")
| 484 | + |
def cleanup(self, handle: str):
    """Do nothing; no cleanup is performed for ``handle``."""
    return None
465 | 486 |
|
466 | 487 | def assign( |
@@ -556,3 +577,55 @@ def _default_headers(self, token: Optional[str] = None) -> dict: |
556 | 577 | if token: |
557 | 578 | headers["Authorization"] = f"Bearer {token}" |
558 | 579 | return headers |
| 580 | + |
| 581 | + |
@dataclass(kw_only=True)
class DGXCloudRequest:
    """Inputs needed to render a DGX Cloud entrypoint script.

    All fields are keyword-only. ``materialize`` renders the ``dgxc.sh.j2``
    template from them, and ``__repr__`` wraps the rendered script in a short
    comment header.
    """

    launch_cmd: list[str]  # joined with single spaces to form the training command
    jobs: list[str]  # only used in the ``__repr__`` header
    executor: DGXCloudExecutor  # supplies the base ``env_vars``
    max_retries: int  # passed to the template unchanged
    extra_env: dict[str, str]  # merged over the executor's ``env_vars``
    launcher: Optional[Launcher] = None  # a FaultTolerance launcher adds extra template variables

    def materialize(self) -> str:
        """Create the content of a DGXC entrypoint script.

        Returns:
            The text produced by filling the ``dgxc.sh.j2`` template.

        Raises:
            AssertionError: If the launcher is a ``FaultTolerance`` instance
                and any of ``cfg_path``, ``finished_flag_file`` or
                ``job_results_file`` is unset or empty.
        """
        # 1. Environment variables: executor defaults first, then extra_env,
        # so extra_env wins on key collisions. Keys are upper-cased and
        # values are inserted verbatim (no shell quoting).
        merged_env = self.executor.env_vars | self.extra_env
        env_vars = [f"export {key.upper()}={value}" for key, value in merged_env.items()]

        # 2. Work out once whether fault tolerance is active.
        launcher = self.launcher
        ft_enabled = bool(launcher and isinstance(launcher, FaultTolerance))

        # 3. Prepare template variables.
        vars_to_fill = {
            "max_retries": self.max_retries,
            "env_vars": env_vars,
            "training_command": " ".join(self.launch_cmd),
            "ft_enabled": ft_enabled,
        }

        # 4. Fault-tolerance injection.
        if ft_enabled:
            # Raised explicitly instead of via `assert` so the check is not
            # stripped when Python runs with -O; the exception type and
            # message are unchanged.
            if not (
                launcher.cfg_path
                and launcher.finished_flag_file
                and launcher.job_results_file
            ):
                raise AssertionError(
                    "Fault Tolerance requires cfg_path, finished_flag_file, and job_results_file"
                )

            vars_to_fill["fault_tol_cfg_path"] = launcher.cfg_path
            vars_to_fill["fault_tol_finished_flag_file"] = launcher.finished_flag_file
            vars_to_fill["fault_tol_job_results_file"] = launcher.job_results_file

        # 5. Render the template.
        return fill_template("dgxc.sh.j2", vars_to_fill)

    def __repr__(self) -> str:
        """Return the rendered script prefixed by a comment header.

        The header names the executor's class and lists ``jobs``. This calls
        ``materialize``, so it renders the template and can raise the same
        ``AssertionError``.
        """
        return f"""# DGXC Entrypoint Script Request
# Executor: {self.executor.__class__.__name__}
# Jobs: {self.jobs}
# ---------------------------------------------------
{self.materialize()}
"""
0 commit comments