From 287b4cd93dfd380382735b230bb3dcc9c951d79a Mon Sep 17 00:00:00 2001
From: Dan Cardin
Date: Wed, 22 Apr 2020 17:38:06 -0400
Subject: [PATCH] "Rebase" docker operator onto current airflow version.

---
 src/airflow_docker/operator.py | 255 ++++++++-------------------------
 1 file changed, 57 insertions(+), 198 deletions(-)

diff --git a/src/airflow_docker/operator.py b/src/airflow_docker/operator.py
index 383602b..4495029 100644
--- a/src/airflow_docker/operator.py
+++ b/src/airflow_docker/operator.py
@@ -36,22 +36,24 @@
 # under the License.
 import ast
 import json
+from tempfile import TemporaryDirectory
+from typing import Dict, Iterable, List, Optional, Union
 
 import airflow.configuration as conf
 import airflow_docker_helper
-import six
 from airflow.exceptions import AirflowConfigException, AirflowException
-from airflow.hooks.docker_hook import DockerHook
-from airflow.models import BaseOperator, SkipMixin
+from airflow.models import SkipMixin
+from airflow.providers.docker.hooks.docker import DockerHook
+from airflow.providers.docker.operators.docker import (
+    DockerOperator as AirflowDockerOperator,
+)
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
 from airflow.utils.decorators import apply_defaults
-from airflow.utils.file import TemporaryDirectory
+from docker import APIClient, tls
+
 from airflow_docker.conf import get_boolean_default, get_default
 from airflow_docker.ext import delegate_to_extensions, register_extensions
 from airflow_docker.utils import get_config
-from docker import APIClient, tls
-
-DEFAULT_HOST_TEMPORARY_DIRECTORY = "/tmp/airflow"
 
 
 class ShortCircuitMixin(SkipMixin):
@@ -79,7 +81,7 @@ def execute(self, context):
 class BranchMixin(SkipMixin):
     def execute(self, context):
         branch = super(BranchMixin, self).execute(context)
-        if isinstance(branch, six.string_types):
+        if isinstance(branch, str):
             branch = [branch]
         self.log.info("Following branch %s", branch)
         self.log.info("Marking other directly downstream tasks as skipped")
@@ -107,7 +109,7 @@ def execute(self, context):
 
 
 @register_extensions
-class BaseDockerOperator(object):
+class DockerOperator(AirflowDockerOperator):
     """
     Execute a command inside a docker container.
 
@@ -127,31 +129,31 @@ class BaseDockerOperator(object):
     :param api_version: Remote API version. Set to ``auto`` to automatically
         detect the server's version.
     :type api_version: str
-    :param auto_remove: Auto-removal of the container on daemon side when the
-        container's process exits.
-        The default is True.
-    :type auto_remove: bool
     :param command: Command to be run in the container. (templated)
     :type command: str or list
+    :param container_name: Name of the container. Optional (templated)
+    :type container_name: str or None
     :param cpus: Number of CPUs to assign to the container.
         This value gets multiplied with 1024. See
        https://docs.docker.com/engine/reference/run/#cpu-share-constraint
     :type cpus: float
-    :param dns: Docker custom DNS servers
-    :type dns: list of strings
-    :param dns_search: Docker custom DNS search domain
-    :type dns_search: list of strings
     :param docker_url: URL of the host running the docker daemon.
         Default is unix://var/run/docker.sock
     :type docker_url: str
     :param environment: Environment variables to set in the container. (templated)
     :type environment: dict
-    :param force_pull: Pull the docker image on every run. Default is True.
+    :param private_environment: Private environment variables to set in the container.
+        These are not templated, and hidden from the website.
+    :type private_environment: dict
+    :param force_pull: Pull the docker image on every run. Default is True.
     :type force_pull: bool
     :param mem_limit: Maximum amount of memory the container can use.
         Either a float value, which represents the limit in bytes,
         or a string like ``128m`` or ``1g``.
     :type mem_limit: float or str
+    :param host_tmp_dir: Specify the location of the temporary directory on the host which will
+        be mapped to tmp_dir. If not provided defaults to using the standard system temp directory.
+    :type host_tmp_dir: str
     :param network_mode: Network mode for the container.
     :type network_mode: str
     :param tls_ca_cert: Path to a PEM-encoded certificate authority
@@ -176,20 +178,30 @@ class BaseDockerOperator(object):
     :type user: int or str
     :param volumes: List of volumes to mount into the container, e.g.
         ``['/host/path:/container/path', '/host/path2:/container/path2:ro']``.
+    :type volumes: list
     :param working_dir: Working directory to set on the container (equivalent
         to the -w switch the docker client)
     :type working_dir: str
-    :param xcom_push: Does the stdout will be pushed to the next step using XCom.
-        The default is False.
-    :type xcom_push: bool
     :param xcom_all: Push all the stdout or just the last line.
         The default is False (last line).
     :type xcom_all: bool
     :param docker_conn_id: ID of the Airflow connection to use
     :type docker_conn_id: str
+    :param dns: Docker custom DNS servers
+    :type dns: list[str]
+    :param dns_search: Docker custom DNS search domain
+    :type dns_search: list[str]
+    :param auto_remove: Auto-removal of the container on daemon side when the
+        container's process exits.
+        The default is True.
+    :type auto_remove: bool
     :param shm_size: Size of ``/dev/shm`` in bytes. The size must be greater than 0.
         If omitted uses system default.
     :type shm_size: int
+    :param tty: Allocate pseudo-TTY to the container.
+        This needs to be set to see logs of the Docker container.
+    :type tty: bool
+
     :param provide_context: If True, make a serialized form of the context
         available.
     :type provide_context: bool
@@ -197,44 +209,20 @@
         If omitted defaults to the "default" key, see `EnvironmentPresetExtension`.
     :type environment_preset: string
     """
-
-    template_fields = ("command", "environment", "extra_kwargs")
-    template_ext = (".sh", ".bash")
+    template_fields = ("command", "environment", "container_name", "extra_kwargs")
 
     known_extra_kwargs = set()
 
     @apply_defaults
     def __init__(
-        self,
-        image,
-        api_version=None,
-        entrypoint=None,
-        command=None,
-        cpus=1.0,
-        docker_url="unix://var/run/docker.sock",
-        environment=None,
-        force_pull=get_boolean_default("force_pull", True),
-        mem_limit=None,
-        network_mode=get_default("network_mode", None),
-        tls_ca_cert=None,
-        tls_client_cert=None,
-        tls_client_key=None,
-        tls_hostname=None,
-        tls_ssl_version=None,
-        tmp_dir="/tmp/airflow",
-        user=None,
-        volumes=None,
-        working_dir=None,
-        xcom_push=False,
-        xcom_all=False,
-        docker_conn_id=None,
-        dns=None,
-        dns_search=None,
-        auto_remove=get_boolean_default("auto_remove", True),
-        shm_size=None,
-        provide_context=False,
-        *args,
-        **kwargs
-    ):
+        self,
+        image: str,
+        force_pull: bool = get_boolean_default("force_pull", True),
+        network_mode: Optional[str] = get_default("network_mode", None),
+        auto_remove: bool = get_boolean_default("auto_remove", True),
+        provide_context: bool = False,
+        *args,
+        **kwargs
+    ) -> None:
         self.extra_kwargs = {
             known_key: kwargs.pop(known_key)
             for known_key in self.known_extra_kwargs
@@ -243,148 +231,26 @@ def __init__(
             if known_key in kwargs
         }
 
-        super(BaseDockerOperator, self).__init__(*args, **kwargs)
-        self.api_version = api_version
-        self.auto_remove = auto_remove
-        self.command = command
-        self.entrypoint = entrypoint
-        self.cpus = cpus
-        self.dns = dns
-        self.dns_search = dns_search
-        self.docker_url = docker_url
-        self.environment = environment or {}
-        self.force_pull = force_pull
-        self.image = image
-        self.mem_limit = mem_limit
-        self.network_mode = network_mode
-        self.tls_ca_cert = tls_ca_cert
-        self.tls_client_cert = tls_client_cert
-        self.tls_client_key = tls_client_key
-        self.tls_hostname = tls_hostname
-        self.tls_ssl_version = tls_ssl_version
-        self.tmp_dir = tmp_dir
-        self.user = user
-        self.volumes = volumes or []
-        self.working_dir = working_dir
-        self.xcom_push_flag = xcom_push
-        self.xcom_all = xcom_all
-        self.docker_conn_id = docker_conn_id
-        self.shm_size = shm_size
-        self.provide_context = provide_context
-
-        self.cli = None
-        self.container = None
-        self._host_client = None  # Shim for attaching a test client
+        super().__init__(*args, image=image, force_pull=force_pull, network_mode=network_mode, auto_remove=auto_remove, **kwargs)
+        self.provide_context = provide_context
 
-    def get_hook(self):
-        return DockerHook(
-            docker_conn_id=self.docker_conn_id,
-            base_url=self.docker_url,
-            version=self.api_version,
-            tls=self.__get_tls_config(),
-        )
+        self._host_client = None  # Shim for attaching a test client
 
-    def _execute(self, context):
-        self.log.info("Starting docker container from image %s", self.image)
+    def execute(self, context):
+        # Hook for creating mounted meta directories
+        self.prepare_host_tmp_dir(context, self.host_tmp_dir)
+        self.prepare_environment(context, self.host_tmp_dir)
 
-        tls_config = self.__get_tls_config()
+        if self.provide_context:
+            self.write_context(context, self.host_tmp_dir)
 
-        if self.docker_conn_id:
-            self.cli = self.get_hook().get_conn()
-        else:
-            self.cli = APIClient(
-                base_url=self.docker_url, version=self.api_version, tls=tls_config
-            )
+        super().execute(context)
 
-        if self.force_pull or len(self.cli.images(name=self.image)) == 0:
-            self.log.info("Pulling docker image %s", self.image)
-            for l in self.cli.pull(self.image, stream=True):
-                output = json.loads(l.decode("utf-8").strip())
if "status" in output: - self.log.info("%s", output["status"]) - - with TemporaryDirectory( - prefix="airflowtmp", dir=self.host_tmp_base_dir - ) as host_tmp_dir: - self.environment["AIRFLOW_TMP_DIR"] = self.tmp_dir - additional_volumes = ["{0}:{1}".format(host_tmp_dir, self.tmp_dir)] - - # Hook for creating mounted meta directories - self.prepare_host_tmp_dir(context, host_tmp_dir) - self.prepare_environment(context, host_tmp_dir) - - if self.provide_context: - self.write_context(context, host_tmp_dir) - - self.container = self.cli.create_container( - command=self.get_command(), - entrypoint=self.entrypoint, - environment=self.environment, - host_config=self.cli.create_host_config( - auto_remove=self.auto_remove, - binds=self.volumes + additional_volumes, - network_mode=self.network_mode, - shm_size=self.shm_size, - dns=self.dns, - dns_search=self.dns_search, - cpu_shares=int(round(self.cpus * 1024)), - mem_limit=self.mem_limit, - ), - image=self.image, - user=self.user, - working_dir=self.working_dir, - ) - self.cli.start(self.container["Id"]) - - line = "" - for line in self.cli.logs(container=self.container["Id"], stream=True): - line = line.strip() - if hasattr(line, "decode"): - line = line.decode("utf-8") - self.log.info(line) - - result = self.cli.wait(self.container["Id"]) - if result["StatusCode"] != 0: - raise AirflowException("docker container failed: " + repr(result)) - - # Move the in-container xcom-pushes into airflow. - result = self.host_client.get_xcom_push_data(host_tmp_dir) - for row in result: - self.xcom_push(context, key=row["key"], value=row["value"]) - - if self.xcom_push_flag: - return ( - self.cli.logs(container=self.container["Id"]) - if self.xcom_all - else str(line) - ) + # Move the in-container xcom-pushes into airflow. + result = self.host_client.get_xcom_push_data(self.host_tmp_dir) + for row in result: + self.xcom_push(context, key=row["key"], value=row["value"]) - return self.do_meta_operation(context, host_tmp_dir) - - def get_command(self): - if self.command is not None and self.command.strip().find("[") == 0: - commands = ast.literal_eval(self.command) - else: - commands = self.command - return commands - - def on_kill(self): - if self.cli is not None: - self.log.info("Stopping docker container") - self.cli.stop(self.container["Id"]) - - def __get_tls_config(self): - tls_config = None - if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key: - tls_config = tls.TLSConfig( - ca_cert=self.tls_ca_cert, - client_cert=(self.tls_client_cert, self.tls_client_key), - verify=True, - ssl_version=self.tls_ssl_version, - assert_hostname=self.tls_hostname, - ) - self.docker_url = self.docker_url.replace("tcp://", "https://") - return tls_config + return self.do_meta_operation(context, self.host_tmp_dir) def do_meta_operation(self, context, host_tmp_dir): pass @@ -400,13 +266,6 @@ def prepare_host_tmp_dir(self, context, host_tmp_dir): def write_context(self, context, host_tmp_dir): self.host_client.write_context(context, host_tmp_dir) - @property - def host_tmp_base_dir(self): - try: - return conf.get("worker", "host_temporary_directory") - except AirflowConfigException: - return DEFAULT_HOST_TEMPORARY_DIRECTORY - def host_meta_dir(self, context, host_tmp_dir): return airflow_docker_helper.get_host_meta_path(host_tmp_dir)