Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 49 additions & 74 deletions nemo_run/core/execution/dgxcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
# limitations under the License.

import base64
import glob
import json
import logging
import os
import queue
import subprocess
import tempfile
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
Expand Down Expand Up @@ -323,7 +322,8 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
launch_script = f"""
ln -s {self.pvc_job_dir}/ /nemo_run
cd /nemo_run/code
{" ".join(cmd)}
mkdir -p {self.pvc_job_dir}/logs
{" ".join(cmd)} 2>&1 | tee -a {self.pvc_job_dir}/logs/output-$HOSTNAME.log
"""
with open(os.path.join(self.job_dir, "launch_script.sh"), "w+") as f:
f.write(launch_script)
Expand Down Expand Up @@ -371,91 +371,66 @@ def status(self, job_id: str) -> Optional[DGXCloudState]:
r_json = response.json()
return DGXCloudState(r_json["phase"])

def _stream_url_sync(self, url: str, headers: dict, q: queue.Queue):
"""Stream a single URL using requests and put chunks into the queue"""
try:
with requests.get(url, stream=True, headers=headers, verify=False) as response:
for line in response.iter_lines(decode_unicode=True):
q.put((url, f"{line}\n"))
except Exception as e:
logger.error(f"Error streaming URL {url}: {e}")

finally:
q.put((url, None))

def fetch_logs(
self,
job_id: str,
stream: bool,
stderr: Optional[bool] = None,
stdout: Optional[bool] = None,
) -> Iterable[str]:
token = self.get_auth_token()
if not token:
logger.error("Failed to retrieve auth token for fetch logs request.")
yield ""

response = requests.get(
f"{self.base_url}/workloads", headers=self._default_headers(token=token)
)
workload_name = next(
(
workload["name"]
for workload in response.json()["workloads"]
if workload["id"] == job_id
),
None,
)
if workload_name is None:
logger.error(f"No workload found with id {job_id}")
yield ""
while self.status(job_id) != DGXCloudState.RUNNING:
logger.info("Waiting for job to start...")
time.sleep(15)

urls = [
f"{self.kube_apiserver_url}/api/v1/namespaces/runai-{self.project_name}/pods/{workload_name}-worker-{i}/log?container=pytorch"
for i in range(self.nodes)
]
cmd = ["tail"]

if stream:
urls = [url + "&follow=true" for url in urls]
cmd.append("-f")

while self.status(job_id) != DGXCloudState.RUNNING:
logger.info("Waiting for job to start...")
time.sleep(15)
# setting linked PVC job directory
nemo_run_home = get_nemorun_home()
job_subdir = self.job_dir[len(nemo_run_home) + 1 :] # +1 to remove the initial backslash
self.pvc_job_dir = os.path.join(self.pvc_nemo_run_dir, job_subdir)

time.sleep(10)
files = []
while len(files) < self.nodes:
files = list(glob.glob(f"{self.pvc_job_dir}/logs/output-*.log"))
logger.info(f"Waiting for {self.nodes - len(files)} log files to be created...")
time.sleep(3)

q = queue.Queue()
active_urls = set(urls)
cmd.extend(files)

# Start threads
threads = [
threading.Thread(
target=self._stream_url_sync, args=(url, self._default_headers(token=token), q)
)
for url in urls
]
for t in threads:
t.start()

# Yield chunks as they arrive
while active_urls:
url, item = q.get()
if item is None or self.status(job_id) in [
DGXCloudState.DELETING,
DGXCloudState.STOPPED,
DGXCloudState.STOPPING,
DGXCloudState.DEGRADED,
DGXCloudState.FAILED,
DGXCloudState.COMPLETED,
DGXCloudState.TERMINATING,
]:
active_urls.discard(url)
else:
yield item

# Wait for threads
for t in threads:
t.join()
logger.info(f"Attempting to stream logs with command: {cmd}")

proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1)

if stream:
while True:
try:
for line in iter(proc.stdout.readline, ""):
if (
line
and not line.rstrip("\n").endswith(".log <==")
and line.rstrip("\n") != ""
):
yield f"{line}"
if proc.poll() is not None:
break
except Exception as e:
logger.error(f"Error streaming logs: {e}")
time.sleep(3)
continue

else:
try:
for line in iter(proc.stdout.readline, ""):
if line:
yield line.rstrip("\n")
if proc.poll() is not None:
break
finally:
proc.terminate()
proc.wait(timeout=2)

def cancel(self, job_id: str):
# Retrieve the authentication token for the REST calls
Expand Down
Loading
Loading