Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions nemo_run/run/ray/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def _status(
result = executor.tunnel.run(cmd)

job_id = result.stdout.strip()
job_id = job_id.split("\n")[-1]

# If job not found in running jobs, check if it's in cluster_map
if not job_id:
Expand Down Expand Up @@ -664,7 +665,11 @@ def run(self):
]
)

jump_arg_str = f"{executor.tunnel.user}@{executor.tunnel.host}"
jump_arg_str = (
f"{executor.tunnel.user}@{executor.tunnel.host}"
if isinstance(executor.tunnel, SSHTunnel)
else None
)
raw_jump_identity = getattr(executor.tunnel, "identity", None)
jump_identity_path_for_proxy = None
if raw_jump_identity:
Expand Down Expand Up @@ -1106,13 +1111,14 @@ def start(
# ------------------------------------------------------------------
# Ship *workdir* over to the remote side (or package via packager)
# ------------------------------------------------------------------
cluster_dir = os.path.join(self.executor.tunnel.job_dir, self.name)
remote_workdir: Optional[str] = None

if workdir:
if isinstance(self.executor.tunnel, SSHTunnel):
# Rsync workdir honouring .gitignore
remote_workdir = os.path.join(self.executor.tunnel.job_dir, self.name, "code")
if not dryrun:
remote_workdir = os.path.join(cluster_dir, "code")
if not dryrun:
if isinstance(self.executor.tunnel, SSHTunnel):
# Rsync workdir honouring .gitignore
self.executor.tunnel.connect()
assert self.executor.tunnel.session is not None, (
"Tunnel session is not connected"
Expand All @@ -1123,11 +1129,24 @@ def start(
remote_workdir,
rsync_opts="--filter=':- .gitignore'",
)
else:
remote_workdir = workdir
else:
os.makedirs(remote_workdir, exist_ok=True)
subprocess.run(
[
"rsync",
"-pthrvz",
"--filter=:- .gitignore",
f"{os.path.join(workdir, '')}",
remote_workdir,
],
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
elif self.executor.packager is not None:
# Use the packager to create an archive which we then extract on the
# submission host and optionally rsync to the target.
remote_workdir = os.path.join(cluster_dir, "code")
if not dryrun:
if isinstance(self.executor.tunnel, SSHTunnel):
package_dir = tempfile.mkdtemp(prefix="nemo_packager_")
Expand Down Expand Up @@ -1157,7 +1176,6 @@ def start(
)

if isinstance(self.executor.tunnel, SSHTunnel):
remote_workdir = os.path.join(self.executor.tunnel.job_dir, self.name, "code")
self.executor.tunnel.connect()
assert self.executor.tunnel.session is not None, (
"Tunnel session is not connected"
Expand All @@ -1169,7 +1187,19 @@ def start(
rsync_opts="--filter=':- .gitignore'",
)
else:
remote_workdir = local_code_extraction_path
os.makedirs(remote_workdir, exist_ok=True)
subprocess.run(
[
"rsync",
"-pthrvz",
"--filter=:- .gitignore",
f"{os.path.join(local_code_extraction_path, '')}",
remote_workdir,
],
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)

assert remote_workdir is not None, "workdir could not be determined"

Expand Down
Loading
Loading