Publish Job Event from Bastion and GKE Runner #706

Open · wants to merge 20 commits into base: main
116 changes: 106 additions & 10 deletions axlearn/cloud/common/bastion.py
@@ -77,6 +77,7 @@
logging.warning("tensorflow_io is not installed -- tf_io may not work with s3://")

from axlearn.cloud.common.cleaner import Cleaner
from axlearn.cloud.common.event_queue import BaseQueueClient, Event
from axlearn.cloud.common.quota import QuotaFn
from axlearn.cloud.common.scheduler import BaseScheduler, JobMetadata, JobScheduler, ResourceMap
from axlearn.cloud.common.types import JobSpec
@@ -88,6 +89,7 @@
Required,
config_class,
config_for_function,
maybe_instantiate,
)
from axlearn.common.utils import Nested

@@ -153,6 +155,64 @@ class JobStatus(str, enum.Enum):
COMPLETED = "COMPLETED"


# Subclass str to be JSON serializable: https://stackoverflow.com/a/51976841
class JobLifecycleState(str, enum.Enum):
"""Represents a lifecycle state for a job.

The lifecycle state is meant for fine-grained reporting and tracking of job state transitions.
For the states corresponding to bastion's internal state machine, see `JobStatus`.
"""

# Job is queued. Bastion detects the new job's jobspec.
QUEUED = "QUEUED"
# Job is starting. The command is starting to run.
STARTING = "STARTING"
# Job is running.
RUNNING = "RUNNING"
# Job is being pre-empted.
PREEMPTING = "PREEMPTING"
# Job is rescheduling.
RESCHEDULING = "RESCHEDULING"
# Job is cancelling. Command is terminating.
CANCELLING = "CANCELLING"
# Job has completed/terminated the command, is running cleanup command (if any).
CLEANING = "CLEANING"
# Job has failed.
FAILED = "FAILED"
# Job finished successfully.
SUCCEEDED = "SUCCEEDED"
# Job is complete.
COMPLETED = "COMPLETED"


@dataclasses.dataclass
class JobLifecycleEvent(Event):
"""Represents a lifecycle event for a job.

Attributes:
job_name: The name of the job associated with this event.
state: The state of the job.
details: Details about the state transition.
job_id: An optional identifier for the job. Defaults to None.
"""

job_name: str
state: JobLifecycleState
details: str
job_id: Optional[str] = None

def serialize(self) -> str:
"""Serializes the job lifecycle event into a JSON string."""
job_event = {
"job_name": self.job_name,
"job_id": self.job_id,
"message": self.details,
"state": self.state,
"timestamp": time.time_ns(),
}
return json.dumps(job_event)


class ValidationError(ValueError):
"""Validation failure (e.g. JobSpec deserialization)."""

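For reference, a minimal usage sketch of the new event types (not part of this diff; the job name and id below are placeholder values):

import json

from axlearn.cloud.common.bastion import JobLifecycleEvent, JobLifecycleState

# Build an event the way the bastion does when it detects a new jobspec.
event = JobLifecycleEvent(
    job_name="demo-job",  # placeholder name for illustration
    state=JobLifecycleState.QUEUED,
    details="PENDING: detected jobspec",
    job_id="demo-job-id",  # optional; serialized as null when omitted
)
payload = json.loads(event.serialize())
# serialize() emits job_name, job_id, message (from `details`), state, and a nanosecond timestamp.
assert payload["state"] == "QUEUED"  # the str-subclassed enum serializes to its value
assert payload["message"] == "PENDING: detected jobspec"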
@@ -605,6 +665,8 @@ class Config(Configurable.Config):
output_dir: Required[str] = REQUIRED
# The quota function to use for getting user group membership and group quota.
quota: Required[QuotaFn] = REQUIRED
# The event publisher that sends events into the queue.
event_publisher: Optional[BaseQueueClient.Config] = None

def __init__(self, cfg: Config):
super().__init__(cfg)
@@ -646,11 +708,22 @@ def __init__(self, cfg: Config):
).instantiate()
self._cleaner: Cleaner = cfg.cleaner.instantiate()
self._uploader = cfg.uploader.set(src_dir=_LOG_DIR, dst_dir=self._log_dir).instantiate()
self._event_publisher = maybe_instantiate(cfg.event_publisher)

def _append_to_job_history(self, job: Job, msg: str):
def _append_to_job_history(self, job: Job, *, msg: str, state: JobLifecycleState):
with tf_io.gfile.GFile(os.path.join(self._job_history_dir, f"{job.spec.name}"), "a") as f:
curr_time = datetime.now(timezone.utc).strftime("%m%d %H:%M:%S")
f.write(f"{curr_time} {msg}\n")
# Publish event into queue.
if self._event_publisher:
self._event_publisher.publish(
JobLifecycleEvent(
job_name=job.spec.name,
job_id=job.spec.metadata.job_id,
state=state,
details=msg,
)
)

def _append_to_project_history(
self, jobs: dict[str, JobMetadata], schedule_results: BaseScheduler.ScheduleResults
@@ -784,7 +857,13 @@ def _sync_jobs(self):
if job_name not in self._active_jobs:
logging.info("Detected new job %s.", job_name)
self._active_jobs[job_name] = active_jobs[job_name]
self._append_to_job_history(active_jobs[job_name], "PENDING: detected jobspec")
self._append_to_job_history(
active_jobs[job_name],
msg="PENDING: detected jobspec",
# When Bastion restarts, we will see this for every job.
# Leave it to the consumer to handle this case.
state=JobLifecycleState.QUEUED,
)
# Detected removed job: exists locally, but not in remote.
elif job_name not in active_jobs:
job = self._active_jobs[job_name]
@@ -826,7 +905,9 @@ def _update_single_job(self, job: Job) -> Job:
# 2. Any job logs are sync'ed to remote log dir. The local log file cannot reliably be
# expected to be present if/when the job is resumed.
if job.command_proc is not None:
self._append_to_job_history(job, "PENDING: pre-empting")
self._append_to_job_history(
job, msg="PENDING: pre-empting", state=JobLifecycleState.PREEMPTING
)
logging.info("Pre-empting job: %s", job.spec.name)
self._wait_and_close_proc(job.command_proc, kill=True)
job.command_proc = None
@@ -838,8 +919,9 @@ def _update_single_job(self, job: Job) -> Job:
if job.command_proc is None:
self._append_to_job_history(
job,
f"ACTIVE: start process command: {job.spec.command} "
msg=f"ACTIVE: start process command: {job.spec.command} "
f"with metadata: {job.state.metadata}",
state=JobLifecycleState.STARTING,
)
env_vars = {f"BASTION_{k.upper()}": v for k, v in job.state.metadata.items()}
serialized_jobspec = io.StringIO()
@@ -854,7 +936,9 @@ def _update_single_job(self, job: Job) -> Job:

# If command is completed, move to CLEANING. Otherwise, it's still RUNNING.
if _is_proc_complete(job.command_proc):
self._append_to_job_history(job, "CLEANING: process finished")
self._append_to_job_history(
job, msg="CLEANING: process finished", state=JobLifecycleState.CLEANING
)
logging.info(
"Job %s stopped gracefully: %s.",
job.spec.name,
@@ -866,11 +950,17 @@ def _update_single_job(self, job: Job) -> Job:
# If job is still running, terminate it. We stay in CANCELLING until it has fully
# exited, after which we move to CLEANING.
if job.command_proc is not None and not _is_proc_complete(job.command_proc):
self._append_to_job_history(job, "CANCELLING: terminating the process")
self._append_to_job_history(
job,
msg="CANCELLING: terminating the process",
state=JobLifecycleState.CANCELLING,
)
logging.info("Sending SIGTERM to job: %s", job.spec.name)
job.command_proc.popen.terminate()
else:
self._append_to_job_history(job, "CLEANING: process terminated")
self._append_to_job_history(
job, msg="CLEANING: process terminated", state=JobLifecycleState.CLEANING
)
job.state.status = JobStatus.CLEANING

elif job.state.status == JobStatus.CLEANING:
@@ -886,14 +976,18 @@ def _update_single_job(self, job: Job) -> Job:
# every time, in case bastion got pre-empted.
if job.spec.cleanup_command and not job.cleanup_proc:
self._append_to_job_history(
job, f"CLEANING: start cleanup command: {job.spec.cleanup_command}"
job,
msg=f"CLEANING: start cleanup command: {job.spec.cleanup_command}",
state=JobLifecycleState.CLEANING,
)
_start_cleanup_command(job)

# If job has no cleanup command, or cleanup command is complete, transition to
# COMPLETED.
if job.cleanup_proc is None or _is_proc_complete(job.cleanup_proc):
self._append_to_job_history(job, "COMPLETED: cleanup finished")
self._append_to_job_history(
job, msg="COMPLETED: cleanup finished", state=JobLifecycleState.COMPLETED
)
logging.info("Job %s finished running cleanup.", job.spec.name)
if job.cleanup_proc is not None:
self._wait_and_close_proc(job.cleanup_proc)
@@ -973,7 +1067,9 @@ def _update_jobs(self):
# see whether this is necessary.
assert job.state.status == JobStatus.ACTIVE and changed_tiers
self._append_to_job_history(
job, f"Rescheduling at a different tier from {old_tier} to {new_tier}"
job,
msg=f"Rescheduling at a different tier from {old_tier} to {new_tier}",
state=JobLifecycleState.RESCHEDULING,
)
job.state.status = JobStatus.PENDING
else:
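As a rough illustration of the new hook (also not part of the diff): when `cfg.event_publisher` is set, `_append_to_job_history` forwards each history entry to the instantiated publisher via `publish()`; when it is left as `None`, `maybe_instantiate` returns `None` and publishing is skipped. The stand-in below is hypothetical and only duck-types the `publish()` call shown above; a real deployment would configure a concrete `BaseQueueClient` implementation instead.

from axlearn.cloud.common.bastion import JobLifecycleEvent, JobLifecycleState


class InMemoryPublisher:
    """Hypothetical stand-in that records serialized events instead of enqueuing them."""

    def __init__(self):
        self.published = []

    def publish(self, event: JobLifecycleEvent):
        # Mirrors the call site in Bastion._append_to_job_history above.
        self.published.append(event.serialize())


publisher = InMemoryPublisher()
publisher.publish(
    JobLifecycleEvent(
        job_name="demo-job",  # placeholder values for illustration
        state=JobLifecycleState.STARTING,
        details="ACTIVE: start process command: sleep 60 with metadata: {}",
    )
)
print(publisher.published[0])  # JSON string with job_name, job_id, message, state, timestamp.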
77 changes: 77 additions & 0 deletions axlearn/cloud/common/bastion_test.py
@@ -27,6 +27,8 @@
Bastion,
BastionDirectory,
Job,
JobLifecycleEvent,
JobLifecycleState,
JobState,
JobStatus,
ValidationError,
@@ -557,6 +559,30 @@ def test_download_compat(self):
)


class TestJobLifecycleEvent(parameterized.TestCase):
"""Tests for JobLifecycleEvent."""

def test_serialize(self):
"""Test serialization of JobLifecycleEvent."""
job_event = JobLifecycleEvent(
job_name="test_job",
state=JobLifecycleState.RUNNING.value,
job_id="12345",
details="test_details",
)
with mock.patch("time.time_ns", return_value=1234567890123456789):
expected_output = json.dumps(
{
"job_name": "test_job",
"job_id": "12345",
"message": "test_details",
"state": "RUNNING",
"timestamp": 1234567890123456789,
}
)
self.assertEqual(job_event.serialize(), expected_output)


class TestRuntimeOptions(parameterized.TestCase):
"""Tests runtime options."""

@@ -668,6 +694,57 @@ def noop_upload_fn(*args, **kwargs):
)
yield cfg.instantiate()

@parameterized.parameters(
[
dict(
popen_spec={
"command": {
"wait.return_value": None,
"poll.side_effect": ValueError,
"terminate.side_effect": ValueError,
},
"cleanup": {
"poll.side_effect": ValueError,
"terminate.side_effect": ValueError,
},
},
),
],
)
def test_append_to_job_history_event_publish(self, popen_spec):
"""Test event publishing."""
mock_proc = _mock_piped_popen_fn(popen_spec)
job = Job(
spec=new_jobspec(
name="test_job",
command="command",
cleanup_command="cleanup",
metadata=JobMetadata(
user_id="test_user",
project_id="test_project",
creation_time=datetime.now(),
resources={"v4": 8},
),
),
state=JobState(status=JobStatus.PENDING),
command_proc=mock_proc("command", "test_command") if "command" in popen_spec else None,
cleanup_proc=mock_proc("cleanup", "test_cleanup") if "cleanup" in popen_spec else None,
)
mock_event_publisher = mock.MagicMock()
with (
self._patch_bastion(popen_spec) as mock_bastion,
mock.patch.object(mock_bastion, "_event_publisher", mock_event_publisher),
):
mock_bastion._append_to_job_history(
job, msg="Job is starting", state=JobLifecycleState.STARTING
)
mock_event_publisher.publish.assert_called()
mock_event_publisher.publish.assert_called_once_with(
JobLifecycleEvent(
job_name="test_job", state=JobLifecycleState.STARTING, details="Job is starting"
)
)

def test_sync_jobs(self):
"""Tests downloading jobspecs."""
