Add reconciliation process #227

Draft: wants to merge 8 commits into master
47 changes: 44 additions & 3 deletions task_processing/plugins/kubernetes/kube_client.py
@@ -1,13 +1,13 @@
import logging
import os
from http import HTTPStatus
from typing import List
from typing import Optional

from kubernetes import client as kube_client
from kubernetes import config as kube_config
from kubernetes.client.exceptions import ApiException
from kubernetes.client.models.v1_pod import V1Pod

logger = logging.getLogger(__name__)

DEFAULT_ATTEMPTS = 2
@@ -184,6 +184,12 @@ def get_pod(
pod_name: str,
attempts: int = DEFAULT_ATTEMPTS,
) -> Optional[V1Pod]:
"""
Wrapper around read_namespaced_pod() in the kubernetes clientlib that adds in
retrying on ApiExceptions.

Returns the V1Pod on success, None otherwise; raises ExceededMaxAttempts once retries are exhausted.
"""
max_attempts = attempts
while attempts:
try:
@@ -209,6 +215,41 @@ def get_pod(
)
raise
logger.info(f"Ran out of retries attempting to fetch pod {pod_name}.")
raise ExceededMaxAttempts(f'Retried fetching pod {pod_name} {max_attempts} times.')

def get_pods(
self, namespace: str, attempts: int = DEFAULT_ATTEMPTS,
) -> Optional[List[V1Pod]]:
"""
Wrapper around list_namespaced_pod() in the kubernetes clientlib that adds in
retrying on ApiExceptions.

Returns a list of V1Pods on success, None if the API returned a 404; raises ExceededMaxAttempts once retries are exhausted.
"""
max_attempts = attempts
while attempts:
try:
pods = self.core.list_namespaced_pod(
namespace=namespace,
).items
return pods
except ApiException as e:
# treat a 404 from the API as "no pods" (mirrors get_pod's handling of unknown pods)
if e.status == 404:
logger.info(f"Found no pods in the namespace {namespace}.")
return None
if not self.maybe_reload_on_exception(exception=e) and attempts:
logger.debug(
f"Failed to fetch pods in {namespace} due to unhandled API exception, "
"retrying.",
exc_info=True
)
attempts -= 1
except Exception:
logger.exception(
f"Failed to fetch pods in {namespace} due to unhandled exception."
)
raise
logger.info(f"Ran out of retries attempting to fetch pods in namespace {namespace}.")
raise ExceededMaxAttempts(f"Retried fetching pods in namespace {namespace} {max_attempts} times.")
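For context, here is a minimal sketch of how a caller might drive the new get_pods() helper. The namespace value and kubeconfig path are illustrative, and the import location of ExceededMaxAttempts is assumed to be the same kube_client module where this diff raises it:

```python
# Hypothetical caller of KubeClient.get_pods(); not part of this diff.
from task_processing.plugins.kubernetes.kube_client import ExceededMaxAttempts
from task_processing.plugins.kubernetes.kube_client import KubeClient

client = KubeClient(kubeconfig_path="/etc/kubernetes/admin.conf")  # illustrative path

try:
    pods = client.get_pods(namespace="example-namespace", attempts=3)
except ExceededMaxAttempts:
    # every attempt hit a retriable ApiException
    pods = None

if pods is None:
    print("nothing to reconcile: the lookup 404'd or retries were exhausted")
else:
    for pod in pods:
        print(pod.metadata.name, pod.status.phase)
```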
103 changes: 79 additions & 24 deletions task_processing/plugins/kubernetes/kubernetes_pod_executor.py
@@ -1,9 +1,14 @@
import logging
import queue
import threading
import time
from queue import Queue
from multiprocessing import cpu_count
from multiprocessing import JoinableQueue
from multiprocessing import Lock
from multiprocessing import Process
from multiprocessing.pool import Pool
from queue import Empty
from time import sleep
from typing import Collection
from typing import List
from typing import Optional

from kubernetes import watch as kube_watch
@@ -52,8 +57,8 @@

logger = logging.getLogger(__name__)

POD_WATCH_THREAD_JOIN_TIMEOUT_S = 1.0
POD_EVENT_THREAD_JOIN_TIMEOUT_S = 1.0
POD_WATCH_PROCESS_JOIN_TIMEOUT_S = 1.0
POD_EVENT_PROCESS_JOIN_TIMEOUT_S = 1.0
QUEUE_GET_TIMEOUT_S = 0.5
SUPPORTED_POD_MODIFIED_EVENT_PHASES = {
"Failed",
@@ -68,6 +73,9 @@
# control plane some breathing room
RETRY_BACKOFF_EXPONENT = 1.5

REFRESH_EXECUTOR_STATE_PROCESS_GRACE = 300
REFRESH_EXECUTOR_STATE_PROCESS_INTERVAL = 120


class KubernetesPodExecutor(TaskExecutor):
TASK_CONFIG_INTERFACE = KubernetesTaskConfig
@@ -107,7 +115,7 @@ def __init__(
self.stopping = False
self.task_metadata: PMap[str, KubernetesTaskMetadata] = pmap()

self.task_metadata_lock = threading.RLock()
self.task_metadata_lock = Lock()
if task_configs:
for task_config in task_configs:
self._initialize_existing_task(task_config)
@@ -117,33 +125,38 @@ def __init__(
# and we've opted to not do that processing in the Pod event watcher thread so as to keep
# that logic for the threads that operate on them as simple as possible and to make it
# possible to cleanly shutdown both of these.
self.pending_events: "Queue[PodEvent]" = Queue()
self.event_queue: "Queue[Event]" = Queue()

self.pending_events: "JoinableQueue[PodEvent]" = JoinableQueue()
self.event_queue: "JoinableQueue[Event]" = JoinableQueue()
# TODO(TASKPROC-243): keep track of resourceVersion so that we can continue event processing
# from where we left off on restarts
self.pod_event_watch_threads = []
self.pod_event_watch_processes: List[Process] = []
self.watches = []
for kube_client in [self.kube_client] + self.watcher_kube_clients:
watch = kube_watch.Watch()
pod_event_watch_thread = threading.Thread(
pod_event_watch_process = Process(
target=self._pod_event_watch_loop,
args=(kube_client, watch),
# ideally this wouldn't be a daemon thread, but a watch.Watch() only checks
# ideally this wouldn't be a daemon process, but a watch.Watch() only checks
# if it should stop after receiving an event - and it's possible that we
# have periods with no events so instead we'll attempt to stop the watch
# and then join() with a small timeout to make sure that, if we shutdown
# with the thread alive, we did not drop any events
# with the process alive, we did not drop any events
daemon=True,
)
pod_event_watch_thread.start()
self.pod_event_watch_threads.append(pod_event_watch_thread)
pod_event_watch_process.start()
self.pod_event_watch_processes.append(pod_event_watch_process)
self.watches.append(watch)

self.pending_event_processing_thread = threading.Thread(
self.pending_event_processing_process = Process(
target=self._pending_event_processing_loop,
)
self.pending_event_processing_thread.start()
self.pending_event_processing_process.start()

self.reconciliation_task_process = Process(
target=self._reconcile_task_loop,
daemon=True,
)
self.reconciliation_task_process.start()

def _initialize_existing_task(self, task_config: KubernetesTaskConfig) -> None:
"""Generates task_metadata in UNKNOWN state for an existing KubernetesTaskConfig.
@@ -468,7 +481,7 @@ def _pending_event_processing_loop(self) -> None:
try:
event = self.pending_events.get(timeout=QUEUE_GET_TIMEOUT_S)
self._process_pod_event(event)
except queue.Empty:
except Empty:
logger.debug(
f"Pending event queue remained empty after {QUEUE_GET_TIMEOUT_S} seconds.",
)
@@ -493,6 +506,46 @@ def _pending_event_processing_loop(self) -> None:

logger.debug("Exiting Pod event processing - stop requested.")

def _reconcile_task_loop(self) -> None:
"""
Run in a separate process to periodically reconcile task_metadata with pod state from k8s.
"""
logger.info(
f"Waiting {REFRESH_EXECUTOR_STATE_PROCESS_GRACE}s before doing work"
)
sleep(REFRESH_EXECUTOR_STATE_PROCESS_GRACE)
logger.debug("Starting Pod task config reconciliation.")
# allocate half of total cpu count for multiprocessing
num_cpus = cpu_count() // 2 or 1
while not self.stopping:
try:
pods = self.kube_client.get_pods(namespace=self.namespace)
except Exception:
logger.exception(
f"Hit an exception attempting to fetch pods in namespace {self.namespace}"
)
pods = None

if pods is not None:
# build a list of (KubernetesTaskConfig, V1Pod) pairs for pods that are
# already tracked in task_metadata
task_configs_pods = [
(self.task_metadata[pod.metadata.name].task_config, pod)
for pod in pods
if pod.metadata.name in self.task_metadata
]

# create a process pool that uses half of total cpus
with Pool(num_cpus) as pool:
# call reconcile function for each task_config in parallel
result = pool.starmap_async(self.reconcile, task_configs_pods)
# wait for all tasks to finish
result.wait()
logger.info(f"Sleeping for {REFRESH_EXECUTOR_STATE_PROCESS_INTERVAL}s")
sleep(REFRESH_EXECUTOR_STATE_PROCESS_INTERVAL)

logger.debug("Exiting Pod task config reconciliation - stop requested.")

def _create_container_definition(
self,
name: str,
@@ -644,7 +697,9 @@ def run(self, task_config: KubernetesTaskConfig) -> Optional[str]:

return None

def reconcile(self, task_config: KubernetesTaskConfig) -> None:
def reconcile(
self, task_config: KubernetesTaskConfig, pod: Optional[V1Pod] = None
) -> None:
pod_name = task_config.pod_name
pod = None
for kube_client in [self.kube_client] + self.watcher_kube_clients:
@@ -751,8 +806,8 @@ def stop(self) -> None:
# grace period to flush the current event to the pending_events queue as well as
# any other clean-up - it's possible that after this join() the thread is still alive
# but in that case we can be reasonably sure that we're not dropping any data.
for pod_event_watch_thread in self.pod_event_watch_threads:
pod_event_watch_thread.join(timeout=POD_WATCH_THREAD_JOIN_TIMEOUT_S)
for pod_event_watch_process in self.pod_event_watch_processes:
pod_event_watch_process.join(timeout=POD_WATCH_PROCESS_JOIN_TIMEOUT_S)

logger.debug("Waiting for all pending PodEvents to be processed...")
# once we've stopped updating the pending events queue, we then wait until we're done
@@ -761,11 +816,11 @@ def stop(self) -> None:
self.pending_events.join()
logger.debug("All pending PodEvents have been processed.")
# and then give ourselves time to do any post-stop cleanup
self.pending_event_processing_thread.join(
timeout=POD_EVENT_THREAD_JOIN_TIMEOUT_S
self.pending_event_processing_process.join(
timeout=POD_EVENT_PROCESS_JOIN_TIMEOUT_S
)

logger.debug("Done stopping KubernetesPodExecutor!")

def get_event_queue(self) -> "Queue[Event]":
def get_event_queue(self) -> "JoinableQueue[Event]":
return self.event_queue
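The heart of the new reconciliation process is the periodic fan-out over Pool.starmap_async. Below is a stripped-down, self-contained sketch of that pattern so the control flow (half-the-CPUs pool, fan out, wait, sleep) can be read in isolation; reconcile, reconcile_loop, get_pairs, and should_stop are stand-ins for the real methods in _reconcile_task_loop above:

```python
# Self-contained sketch of the reconciliation fan-out pattern (illustrative only).
from multiprocessing import cpu_count
from multiprocessing.pool import Pool
from time import sleep


def reconcile(task_name: str, phase: str) -> None:
    # stand-in for KubernetesPodExecutor.reconcile(task_config, pod)
    print(f"reconciling {task_name}: pod phase is {phase}")


def reconcile_loop(get_pairs, interval_s: float, should_stop) -> None:
    # mirror the executor's heuristic: use half the CPUs, but at least one
    num_cpus = cpu_count() // 2 or 1
    while not should_stop():
        pairs = get_pairs()  # [(task_name, phase), ...]; the executor pairs task_configs with pods
        with Pool(num_cpus) as pool:
            result = pool.starmap_async(reconcile, pairs)
            result.wait()  # block until every reconcile call has finished
        sleep(interval_s)


if __name__ == "__main__":
    # run two iterations against canned data, then stop
    stops = iter([False, False, True])
    reconcile_loop(
        get_pairs=lambda: [("task-a", "Running"), ("task-b", "Succeeded")],
        interval_s=0.1,
        should_stop=lambda: next(stops),
    )
```

One design note: because the diff builds the Pool inside the while loop, worker processes are recreated on every pass; if the reconciliation interval ever shrinks, it may be worth hoisting the Pool out of the loop.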
7 changes: 7 additions & 0 deletions tests/unit/conftest.py
@@ -1,3 +1,4 @@
import multiprocessing
import threading

import mock
@@ -14,3 +15,9 @@ def mock_sleep():
def mock_Thread():
with mock.patch.object(threading, "Thread") as mock_Thread:
yield mock_Thread


@pytest.fixture
def mock_Process():
with mock.patch.object(multiprocessing, 'Process') as mock_Process:
yield mock_Process
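One thing worth double-checking with this fixture: kubernetes_pod_executor.py now does `from multiprocessing import Process`, so the executor binds Process into its own module namespace at import time. Patching the attribute on the multiprocessing module (as above) may therefore not intercept the executor's constructor calls the way the threading.Thread patch did for the old `threading.Thread(...)` lookups. A hedged alternative (fixture name is just a suggestion) is to patch the name where it is looked up:

```python
# Alternative fixture sketch: patch Process where the executor looks it up.
import mock
import pytest


@pytest.fixture
def mock_executor_Process():
    with mock.patch(
        "task_processing.plugins.kubernetes.kubernetes_pod_executor.Process",
        autospec=True,
    ) as m:
        yield m
```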
19 changes: 19 additions & 0 deletions tests/unit/plugins/kubernetes/kube_client_test.py
@@ -119,3 +119,22 @@ def test_KubeClient_get_pod():
mock_kube_client.CoreV1Api().read_namespaced_pod.assert_called_once_with(
namespace="ns", name="pod-name"
)


def test_KubeClient_get_pods():
with mock.patch(
"task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
autospec=True
), mock.patch(
"task_processing.plugins.kubernetes.kube_client.kube_client",
autospec=True
) as mock_kube_client, mock.patch.dict(
os.environ, {"KUBECONFIG": "/another/kube/config.conf"}
):
mock_config_path = "/OVERRIDE.conf"
mock_kube_client.CoreV1Api().list_namespaced_pod.return_value = mock.Mock()
client = KubeClient(kubeconfig_path=mock_config_path)
client.get_pods(namespace='ns', attempts=1)
mock_kube_client.CoreV1Api().list_namespaced_pod.assert_called_once_with(
namespace='ns'
)
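If coverage of the failure path is wanted, a companion test for the retry-exhaustion branch could look like the sketch below. It assumes pytest, ApiException, and ExceededMaxAttempts are importable in this test module, and that maybe_reload_on_exception returns falsy for a 500:

```python
def test_KubeClient_get_pods_exhausts_retries():
    with mock.patch(
        "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
        autospec=True,
    ), mock.patch(
        "task_processing.plugins.kubernetes.kube_client.kube_client",
        autospec=True,
    ) as mock_kube_client, mock.patch.dict(
        os.environ, {"KUBECONFIG": "/another/kube/config.conf"}
    ):
        # every list call raises a retriable (non-404) ApiException
        mock_kube_client.CoreV1Api().list_namespaced_pod.side_effect = ApiException(status=500)
        client = KubeClient(kubeconfig_path="/OVERRIDE.conf")
        with pytest.raises(ExceededMaxAttempts):
            client.get_pods(namespace="ns", attempts=1)
```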
4 changes: 2 additions & 2 deletions tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
@@ -41,7 +41,7 @@


@pytest.fixture
def k8s_executor(mock_Thread):
def k8s_executor(mock_Process):
with mock.patch(
"task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
autospec=True,
@@ -90,7 +90,7 @@ def mock_task_configs():


@pytest.fixture
def k8s_executor_with_tasks(mock_Thread, mock_task_configs):
def k8s_executor_with_tasks(mock_Process, mock_task_configs):
with mock.patch(
"task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
autospec=True,