diff --git a/nemo_run/core/execution/lepton.py b/nemo_run/core/execution/lepton.py index f3cd2c92..7c1fe83f 100644 --- a/nemo_run/core/execution/lepton.py +++ b/nemo_run/core/execution/lepton.py @@ -82,7 +82,13 @@ def copy_directory_data_command(self, local_dir_path: str, dest_path: str) -> Li full_command = ["sh", "-c", cmd] return full_command - def move_data(self, sleep: float = 10, timeout: int = 600, poll_interval: int = 5) -> None: + def move_data( + self, + sleep: float = 10, + timeout: int = 600, + poll_interval: int = 5, + unknowns_grace_period: int = 60, + ) -> None: """ Moves job directory into remote storage and deletes the workload after completion. """ @@ -121,20 +127,39 @@ def move_data(self, sleep: float = 10, timeout: int = 600, poll_interval: int = job_id = response.metadata.id_ start_time = time.time() - count = 0 while True: if time.time() - start_time > timeout: raise TimeoutError(f"Job {job_id} did not complete within {timeout} seconds.") + current_job = client.job.get(job_id) current_job_status = current_job.status.state - if count > 0 and current_job_status in [ - LeptonJobState.Completed, - LeptonJobState.Failed, - LeptonJobState.Unknown, - ]: + if ( + current_job_status == LeptonJobState.Completed + or current_job_status == LeptonJobState.Failed + ): break - count += 1 + elif current_job_status == LeptonJobState.Unknown: + logging.warning( + f"Job {job_id} entered Unknown state, checking for up to {unknowns_grace_period} seconds every 2 seconds..." + ) + unknown_start_time = time.time() + recovered = False + while time.time() - unknown_start_time < unknowns_grace_period: + time.sleep(2) + current_job = client.job.get(job_id) + current_job_status = current_job.status.state + if current_job_status != LeptonJobState.Unknown: + logging.info( + f"Job {job_id} recovered from Unknown state to {current_job_status}" + ) + recovered = True + break + if not recovered: + logging.error( + f"Job {job_id} has been in Unknown state for more than {unknowns_grace_period} seconds" + ) + break time.sleep(poll_interval) if current_job_status != LeptonJobState.Completed: