Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/nemo_run/core/execution/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
filenames.append(filename)
return filenames

def create_job_dir(self):
    """Ensure the directory at ``self.job_dir`` exists.

    Missing parent directories are created as well, and a directory that
    is already present is left as-is instead of raising.
    """
    target_dir = self.job_dir
    os.makedirs(target_dir, exist_ok=True)

def cleanup(self, handle: str):
    """Cleanup hook keyed by ``handle``; this implementation is an empty stub.

    NOTE(review): presumably subclasses override this to release resources
    associated with the given handle -- confirm against the executors.
    """
    ...


Expand Down
1 change: 0 additions & 1 deletion src/nemo_run/core/execution/dgxcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def assign(
self.experiment_dir = exp_dir
self.job_dir = os.path.join(exp_dir, task_dir)
self.experiment_id = exp_id
os.makedirs(self.job_dir, exist_ok=True)
assert any(
map(
lambda x: os.path.commonpath(
Expand Down
1 change: 0 additions & 1 deletion src/nemo_run/core/execution/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ def assign(
self.experiment_id = exp_id
self.experiment_dir = exp_dir
self.job_dir = os.path.join(exp_dir, task_dir)
os.makedirs(self.job_dir, exist_ok=True)

def nnodes(self) -> int:
    """Return the node count for this executor, which is always 1."""
    return 1
Expand Down
1 change: 0 additions & 1 deletion src/nemo_run/core/execution/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def assign(
self.experiment_id = exp_id
self.experiment_dir = exp_dir
self.job_dir = os.path.join(exp_dir, task_dir)
os.makedirs(self.job_dir, exist_ok=True)

def nnodes(self) -> int:
    """Return the node count for this executor, which is always 1."""
    return 1
Expand Down
2 changes: 0 additions & 2 deletions src/nemo_run/core/execution/skypilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,6 @@ def assign(
self.job_dir = os.path.join(exp_dir, task_dir)
self.experiment_id = exp_id

os.makedirs(self.job_dir, exist_ok=True)

def package(self, packager: Packager, job_name: str):
assert self.experiment_id, "Executor not assigned to an experiment."
if isinstance(packager, GitArchivePackager):
Expand Down
1 change: 0 additions & 1 deletion src/nemo_run/core/execution/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,6 @@ def assign(
self.job_dir = os.path.join(exp_dir, task_dir)
self.experiment_id = exp_id

os.makedirs(self.job_dir, exist_ok=True)
self.tunnel._set_job_dir(self.experiment_id)

def get_launcher_prefix(self) -> Optional[list[str]]:
Expand Down
41 changes: 24 additions & 17 deletions src/nemo_run/run/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,7 @@ def __init__(
self._runner = get_runner()

if not _reconstruct:
os.makedirs(self._exp_dir, exist_ok=False)

self.executor = executor if executor else LocalExecutor()
self._save_config()
else:
assert isinstance(executor, Executor)
self.executor = executor
Expand All @@ -334,6 +331,10 @@ def to_config(self) -> Config:
log_level=self.log_level,
)

def _save_experiment(self, exist_ok: bool = False):
    """Create the experiment directory, then persist the experiment config.

    Args:
        exist_ok: Passed straight to ``os.makedirs``. With the default
            ``False``, an already-existing ``self._exp_dir`` raises
            ``FileExistsError``.
    """
    experiment_dir = self._exp_dir
    os.makedirs(experiment_dir, exist_ok=exist_ok)
    # The directory must exist before the config is saved.
    self._save_config()

def _save_config(self):
with open(os.path.join(self._exp_dir, self.__class__._CONFIG_FILE), "w+") as f:
f.write(ZlibJSONSerializer().serialize(self.to_config()))
Expand Down Expand Up @@ -389,6 +390,13 @@ def _load_jobs(self) -> list[Job | JobGroup]:

return jobs

def _prepare(self, exist_ok: bool = False):
    """Save the experiment, prepare each of its jobs, then save the jobs.

    Args:
        exist_ok: Forwarded to ``_save_experiment``; controls whether an
            already-existing experiment directory is tolerated.
    """
    self._save_experiment(exist_ok=exist_ok)

    # Jobs are prepared in list order, after the experiment itself is saved.
    for pending_job in self.jobs:
        pending_job.prepare()

    # Save the jobs only after every job has been prepared.
    self._save_jobs()

def _add_single_job(
self,
task: Union[Partial, Script],
Expand Down Expand Up @@ -434,7 +442,6 @@ def _add_single_job(
plugin.assign(self._id)
plugin.setup(cloned, executor)

job.prepare()
self._jobs.append(job)
return job.id

Expand Down Expand Up @@ -482,7 +489,6 @@ def _add_job_group(
assert isinstance(_executor, Executor)
plugin.setup(task, _executor)

job_group.prepare()
self._jobs.append(job_group)
return job_group.id

Expand Down Expand Up @@ -552,16 +558,17 @@ def add(
dependencies=dependencies.copy() if dependencies else None,
)

self._save_jobs()
return job_id

def dryrun(self, log: bool = True):
def dryrun(self, log: bool = True, exist_ok: bool = False, delete_exp_dir: bool = True):
"""
Logs the raw scripts that will be executed for each task.
"""
if log:
self.console.log(f"[bold magenta]Experiment {self._id} dryrun...")

self._prepare(exist_ok=exist_ok)

for job in self.jobs:
if isinstance(job, Job):
if log:
Expand All @@ -571,6 +578,9 @@ def dryrun(self, log: bool = True):
self.console.log(f"[bold magenta]Task Group {job.id}\n")
job.launch(wait=False, runner=self._runner, dryrun=True, direct=False, log_dryrun=log)

if delete_exp_dir:
shutil.rmtree(self._exp_dir)

def run(
self,
sequential: bool = False,
Expand Down Expand Up @@ -614,6 +624,9 @@ def run(
self.console.log("[bold magenta]Experiment in inspection mode...")
return

# Prepare experiment before running
self._prepare()

if direct:
self.console.log(
"[bold magenta]Running the experiment with direct=True. "
Expand All @@ -637,8 +650,8 @@ def run(
os.path.join(job.executor.job_dir, f"log_{job.id}_direct_run.out")
):
job.launch(wait=True, direct=True, runner=self._runner)
self._save_jobs()

self._save_jobs()
self._launched = any(map(lambda job: job.launched, self.jobs))
self._direct = True
return
Expand Down Expand Up @@ -669,7 +682,7 @@ def run(
for i in range(1, len(self.jobs)):
self.jobs[i].dependencies.append(self.jobs[i - 1].id)

self.dryrun(log=False)
self.dryrun(log=False, exist_ok=True, delete_exp_dir=False)
for tunnel in self.tunnels.values():
if isinstance(tunnel, SSHTunnel):
tunnel.connect()
Expand Down Expand Up @@ -746,14 +759,14 @@ def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
job.executor.dependencies = deps # type: ignore
job.launch(wait=False, runner=self._runner)

self._save_jobs()
except Exception as e:
self.console.log(f"Error running job {job.id}: {e}")
raise e

if wait:
self._wait_for_jobs(jobs=[job_map[node] for node in level])

self._save_jobs()
self._launched = any(map(lambda job: job.launched, self.jobs))
self._waited = wait

Expand Down Expand Up @@ -955,7 +968,6 @@ def reset(self) -> "Experiment":
old_id, old_exp_dir, old_launched = self._id, self._exp_dir, self._launched
self._id = f"{self._title}_{int(time.time())}"
self._exp_dir = os.path.join(NEMORUN_HOME, "experiments", self._title, self._id)
os.makedirs(self._exp_dir, exist_ok=False)
self._launched = False
self._live_progress = None

Expand All @@ -967,12 +979,9 @@ def reset(self) -> "Experiment":
_current_experiment.set(self)
_set_current_experiment = True

if "__main__.py" in os.listdir(old_exp_dir):
shutil.copy(os.path.join(old_exp_dir, "__main__.py"), self._exp_dir)

try:
if "__external_main__" not in sys.modules:
maybe_load_external_main(self._exp_dir)
maybe_load_external_main(old_exp_dir)

for job in jobs:
if isinstance(job, Job):
Expand Down Expand Up @@ -1022,8 +1031,6 @@ def reset(self) -> "Experiment":
self._current_experiment_token = None

self._reconstruct = False
self._save_config()

return self

def _initialize_live_progress(self):
Expand Down
2 changes: 2 additions & 0 deletions src/nemo_run/run/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def logs(self, runner: Runner, regex: str | None = None):
)

def prepare(self):
self.executor.create_job_dir()
self._executable = package(
self.id, self.task, executor=self.executor, serialize_to_file=True
)
Expand Down Expand Up @@ -306,6 +307,7 @@ def logs(self, runner: Runner, regex: str | None = None):
)

def prepare(self):
self.executor.create_job_dir()
self._executables: list[tuple[AppDef, Executor]] = []
for i, task in enumerate(self.tasks):
executor = self.executors if self._merge else self.executors[i] # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion test/core/execution/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_local_executor_assign():

assert executor.experiment_id == "test_exp"
assert executor.job_dir == os.path.join(tmp_dir, "test_task")
assert os.path.exists(executor.job_dir)
assert not os.path.exists(executor.job_dir)


def test_local_executor_nnodes():
Expand Down
Loading