diff --git a/airflow/providers/cncf/kubernetes/CHANGELOG.rst b/airflow/providers/cncf/kubernetes/CHANGELOG.rst index 514035c8cac04..bb2d236594044 100644 --- a/airflow/providers/cncf/kubernetes/CHANGELOG.rst +++ b/airflow/providers/cncf/kubernetes/CHANGELOG.rst @@ -19,6 +19,19 @@ Changelog --------- +main +.... + +Features +~~~~~~~~ + +KubernetesPodOperator now uses KubernetesHook +````````````````````````````````````````````` + +Previously, KubernetesPodOperator relied on core Airflow configuration (namely the settings for the Kubernetes executor) for certain settings used in client generation. Now KubernetesPodOperator uses KubernetesHook, and the consideration of core k8s settings is officially deprecated. + +If you are using the Airflow configuration settings (i.e. as opposed to operator params) to configure the kubernetes client, then prior to the next major release you will need to add an Airflow connection and set your KPO tasks to use that connection. + 4.0.2 ..... diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index 5719918ce7b91..e15dce67ef40a 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -16,10 +16,13 @@ # under the License. import sys import tempfile -from typing import Any, Dict, Generator, Optional, Tuple, Union +import warnings +from typing import Any, Dict, Generator, List, Optional, Tuple, Union from kubernetes.config import ConfigException +from airflow.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive + if sys.version_info >= (3, 8): from functools import cached_property else: @@ -63,6 +66,14 @@ class KubernetesHook(BaseHook): :param conn_id: The :ref:`kubernetes connection ` to Kubernetes cluster. + :param client_configuration: Optional dictionary of client configuration params. + Passed on to kubernetes client. 
+ :param cluster_context: Optionally specify a context to use (e.g. if you have multiple + in your kubeconfig). + :param config_file: Path to kubeconfig file. + :param in_cluster: Set to ``True`` if running from within a kubernetes cluster. + :param disable_verify_ssl: Set to ``True`` if SSL verification should be disabled. + :param disable_tcp_keepalive: Set to ``True`` if you want to disable keepalive logic. """ conn_name_attr = 'kubernetes_conn_id' @@ -91,6 +102,8 @@ def get_connection_form_widgets() -> Dict[str, Any]: "extra__kubernetes__cluster_context": StringField( lazy_gettext('Cluster context'), widget=BS3TextFieldWidget() ), + "extra__kubernetes__disable_verify_ssl": BooleanField(lazy_gettext('Disable SSL')), + "extra__kubernetes__disable_tcp_keepalive": BooleanField(lazy_gettext('Disable TCP keepalive')), } @staticmethod @@ -108,6 +121,8 @@ def __init__( cluster_context: Optional[str] = None, config_file: Optional[str] = None, in_cluster: Optional[bool] = None, + disable_verify_ssl: Optional[bool] = None, + disable_tcp_keepalive: Optional[bool] = None, ) -> None: super().__init__() self.conn_id = conn_id @@ -115,6 +130,16 @@ def __init__( self.cluster_context = cluster_context self.config_file = config_file self.in_cluster = in_cluster + self.disable_verify_ssl = disable_verify_ssl + self.disable_tcp_keepalive = disable_tcp_keepalive + + # these params used for transition in KPO to K8s hook + # for a deprecation period we will continue to consider k8s settings from airflow.cfg + self._deprecated_core_disable_tcp_keepalive: Optional[bool] = None + self._deprecated_core_disable_verify_ssl: Optional[bool] = None + self._deprecated_core_in_cluster: Optional[bool] = None + self._deprecated_core_cluster_context: Optional[str] = None + self._deprecated_core_config_file: Optional[str] = None @staticmethod def _coalesce_param(*params): @@ -122,23 +147,51 @@ def _coalesce_param(*params): if param is not None: return param - def get_conn(self) -> Any: - """Returns 
kubernetes api session for use with requests""" + @cached_property + def conn_extras(self): if self.conn_id: connection = self.get_connection(self.conn_id) extras = connection.extra_dejson else: extras = {} + return extras + + def _get_field(self, field_name): + if field_name.startswith('extra_'): + raise ValueError( + f"Got prefixed name {field_name}; please remove the 'extra__kubernetes__' prefix " + f"when using this method." + ) + if field_name in self.conn_extras: + return self.conn_extras[field_name] or None + prefixed_name = f"extra__kubernetes__{field_name}" + return self.conn_extras.get(prefixed_name) or None + + @staticmethod + def _deprecation_warning_core_param(deprecation_warnings): + settings_list_str = ''.join([f"\n\t{k}={v!r}" for k, v in deprecation_warnings]) + warnings.warn( + f"\nApplying core Airflow settings from section [kubernetes] with the following keys:" + f"{settings_list_str}\n" + "In a future release, KubernetesPodOperator will no longer consider core\n" + "Airflow settings; define an Airflow connection instead.", + DeprecationWarning, + ) + + def get_conn(self) -> Any: + """Returns kubernetes api session for use with requests""" + in_cluster = self._coalesce_param( - self.in_cluster, extras.get("extra__kubernetes__in_cluster") or None + self.in_cluster, self.conn_extras.get("extra__kubernetes__in_cluster") or None ) cluster_context = self._coalesce_param( - self.cluster_context, extras.get("extra__kubernetes__cluster_context") or None + self.cluster_context, self.conn_extras.get("extra__kubernetes__cluster_context") or None ) kubeconfig_path = self._coalesce_param( - self.config_file, extras.get("extra__kubernetes__kube_config_path") or None + self.config_file, self.conn_extras.get("extra__kubernetes__kube_config_path") or None ) - kubeconfig = extras.get("extra__kubernetes__kube_config") or None + + kubeconfig = self.conn_extras.get("extra__kubernetes__kube_config") or None num_selected_configuration = len([o for o in [in_cluster, 
kubeconfig, kubeconfig_path] if o]) if num_selected_configuration > 1: @@ -147,6 +200,43 @@ def get_conn(self) -> Any: "kube_config, in_cluster are mutually exclusive. " "You can only use one option at a time." ) + + disable_verify_ssl = self._coalesce_param( + self.disable_verify_ssl, _get_bool(self._get_field("disable_verify_ssl")) + ) + disable_tcp_keepalive = self._coalesce_param( + self.disable_tcp_keepalive, _get_bool(self._get_field("disable_tcp_keepalive")) + ) + + # BEGIN apply settings from core kubernetes configuration + # this section should be removed in next major release + deprecation_warnings: List[Tuple[str, Any]] = [] + if disable_verify_ssl is None and self._deprecated_core_disable_verify_ssl is True: + deprecation_warnings.append(('verify_ssl', False)) + disable_verify_ssl = self._deprecated_core_disable_verify_ssl + # by default, hook will try in_cluster first. so we only need to + # apply core airflow config and alert when False and in_cluster not otherwise set. + if in_cluster is None and self._deprecated_core_in_cluster is False: + deprecation_warnings.append(('in_cluster', self._deprecated_core_in_cluster)) + in_cluster = self._deprecated_core_in_cluster + if not cluster_context and self._deprecated_core_cluster_context: + deprecation_warnings.append(('cluster_context', self._deprecated_core_cluster_context)) + cluster_context = self._deprecated_core_cluster_context + if not kubeconfig_path and self._deprecated_core_config_file: + deprecation_warnings.append(('config_file', self._deprecated_core_config_file)) + kubeconfig_path = self._deprecated_core_config_file + if disable_tcp_keepalive is None and self._deprecated_core_disable_tcp_keepalive is True: + deprecation_warnings.append(('enable_tcp_keepalive', False)) + disable_tcp_keepalive = True + if deprecation_warnings: + self._deprecation_warning_core_param(deprecation_warnings) + # END apply settings from core kubernetes configuration + + if disable_verify_ssl is True: + 
_disable_verify_ssl() + if disable_tcp_keepalive is not True: + _enable_tcp_keepalive() + if in_cluster: self.log.debug("loading kube_config from: in_cluster configuration") config.load_incluster_config() @@ -316,3 +406,18 @@ def get_pod_logs( _preload_content=False, namespace=namespace if namespace else self.get_namespace(), ) + + +def _get_bool(val) -> Optional[bool]: + """ + Converts val to bool if can be done with certainty. + If we cannot infer intention we return None. + """ + if isinstance(val, bool): + return val + elif isinstance(val, str): + if val.strip().lower() == 'true': + return True + elif val.strip().lower() == 'false': + return False + return None diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py index fba88f80e17de..69f120e823b3b 100644 --- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py +++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -25,8 +25,9 @@ from kubernetes.client import CoreV1Api, models as k8s +from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.kubernetes import kube_client, pod_generator +from airflow.kubernetes import pod_generator from airflow.kubernetes.pod_generator import PodGenerator from airflow.kubernetes.secret import Secret from airflow.models import BaseOperator @@ -142,6 +143,7 @@ class KubernetesPodOperator(BaseOperator): :param priority_class_name: priority class name for the launched Pod :param termination_grace_period: Termination grace period if task killed in UI, defaults to kubernetes default + :param kubernetes_conn_id: To retrieve credentials for your k8s cluster from an Airflow connection """ BASE_CONTAINER_NAME = 'base' @@ -209,7 +211,6 @@ def __init__( if kwargs.get('xcom_push') is not None: raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") super().__init__(resources=None, **kwargs) - 
self.kubernetes_conn_id = kubernetes_conn_id self.do_xcom_push = do_xcom_push self.image = image @@ -324,19 +325,20 @@ def _get_ti_pod_labels(context: Optional[dict] = None, include_try_number: bool def pod_manager(self) -> PodManager: return PodManager(kube_client=self.client) - @cached_property - def client(self) -> CoreV1Api: - if self.kubernetes_conn_id: - hook = KubernetesHook(conn_id=self.kubernetes_conn_id) - return hook.core_v1_client - - kwargs: Dict[str, Any] = dict( - cluster_context=self.cluster_context, + def get_hook(self): + hook = KubernetesHook( + conn_id=self.kubernetes_conn_id, + in_cluster=self.in_cluster, config_file=self.config_file, + cluster_context=self.cluster_context, ) - if self.in_cluster is not None: - kwargs.update(in_cluster=self.in_cluster) - return kube_client.get_kube_client(**kwargs) + self._patch_deprecated_k8s_settings(hook) + return hook + + @cached_property + def client(self) -> CoreV1Api: + hook = self.get_hook() + return hook.core_v1_client def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.V1Pod]: """Returns an already-running pod for this task instance if one exists.""" @@ -573,6 +575,39 @@ def dry_run(self) -> None: pod = self.build_pod_request_obj() print(yaml.dump(prune_dict(pod.to_dict(), mode='strict'))) + def _patch_deprecated_k8s_settings(self, hook: KubernetesHook): + """ + Here we read config from core Airflow config [kubernetes] section. + In a future release we will stop looking at this section and require users + to use Airflow connections to configure KPO. + + When we find values there that we need to apply on the hook, we patch special + hook attributes here. + """ + + # default for enable_tcp_keepalive is True; patch if False + if conf.getboolean('kubernetes', 'enable_tcp_keepalive') is False: + hook._deprecated_core_disable_tcp_keepalive = True + + # default verify_ssl is True; patch if False. 
+ if conf.getboolean('kubernetes', 'verify_ssl') is False: + hook._deprecated_core_disable_verify_ssl = True + + # default for in_cluster is True; patch if False and no KPO param. + conf_in_cluster = conf.getboolean('kubernetes', 'in_cluster') + if self.in_cluster is None and conf_in_cluster is False: + hook._deprecated_core_in_cluster = conf_in_cluster + + # there's no default for cluster context; if we get something (and no KPO param) patch it. + conf_cluster_context = conf.get('kubernetes', 'cluster_context', fallback=None) + if not self.cluster_context and conf_cluster_context: + hook._deprecated_core_cluster_context = conf_cluster_context + + # there's no default for config_file; if we get something (and no KPO param) patch it. + conf_config_file = conf.get('kubernetes', 'config_file', fallback=None) + if not self.config_file and conf_config_file: + hook._deprecated_core_config_file = conf_config_file + class _suppress(AbstractContextManager): """ diff --git a/docs/apache-airflow-providers-cncf-kubernetes/connections/kubernetes.rst b/docs/apache-airflow-providers-cncf-kubernetes/connections/kubernetes.rst index c0a8539832a4c..3d0dca268e97e 100644 --- a/docs/apache-airflow-providers-cncf-kubernetes/connections/kubernetes.rst +++ b/docs/apache-airflow-providers-cncf-kubernetes/connections/kubernetes.rst @@ -58,12 +58,17 @@ Kube config (JSON format) Namespace Default Kubernetes namespace for the connection. -When specifying the connection in environment variable you should specify -it using URI syntax. +Cluster context + When using a kube config, you can specify which context to use. -Note that all components of the URI should be URL-encoded. +Disable verify SSL + Set this to disable SSL certificate verification. By default, SSL certificates are verified. -For example: +Disable TCP keepalive + TCP keepalive is a feature (enabled by default) that tries to keep long-running connections + alive. Set this parameter to True to disable this feature. 
+ +Example storing connection in env var using URI format: .. code-block:: bash diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index 56705c7a9d6d1..49928274517ac 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -33,9 +33,9 @@ from kubernetes.client.rest import ApiException from airflow.exceptions import AirflowException -from airflow.kubernetes import kube_client from airflow.kubernetes.secret import Secret from airflow.models import DAG, XCOM_RETURN_KEY, DagRun, TaskInstance +from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults @@ -43,6 +43,9 @@ from airflow.utils.types import DagRunType from airflow.version import version as airflow_version +HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook" +POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager" + def create_context(task): dag = DAG(dag_id="dag") @@ -123,7 +126,8 @@ def setUp(self): } def tearDown(self) -> None: - client = kube_client.get_kube_client(in_cluster=False) + hook = KubernetesHook(conn_id=None, in_cluster=False) + client = hook.core_v1_client client.delete_collection_namespaced_pod(namespace="default") import time @@ -632,10 +636,12 @@ def test_xcom_push(self, xcom_push): self.expected_pod['spec']['containers'].append(container) assert self.expected_pod == actual_pod - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion") - @mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def 
test_envs_from_secrets(self, mock_client, await_pod_completion_mock, create_pod): + @mock.patch(f"{POD_MANAGER_CLASS}.create_pod") + @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") + @mock.patch(HOOK_CLASS, new=MagicMock) + def test_envs_from_secrets(self, await_pod_completion_mock, create_pod): + # todo: This isn't really a system test + # GIVEN secret_ref = 'secret_name' @@ -696,6 +702,7 @@ def test_env_vars(self): assert self.expected_pod == actual_pod def test_pod_template_file_system(self): + """Note: this test requires that you have a namespace ``mem-example`` in your cluster.""" fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml' k = KubernetesPodOperator( task_id="task" + self.get_current_task_name(), @@ -872,11 +879,12 @@ def test_init_container(self): ] assert self.expected_pod == actual_pod - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.extract_xcom") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion") - @mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def test_pod_template_file(self, mock_client, await_pod_completion_mock, create_mock, extract_xcom_mock): + @mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom") + @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") + @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock) + @mock.patch(HOOK_CLASS, new=MagicMock) + def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock): + # todo: This isn't really a system test extract_xcom_mock.return_value = '{}' path = sys.path[0] + '/tests/kubernetes/pod.yaml' k = KubernetesPodOperator( @@ -958,11 +966,15 @@ def test_pod_template_file(self, mock_client, await_pod_completion_mock, create_ del actual_pod['metadata']['labels']['airflow_version'] assert expected_dict == actual_pod - 
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion") - @mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def test_pod_priority_class_name(self, mock_client, await_pod_completion_mock, create_mock): - """Test ability to assign priorityClassName to pod""" + @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") + @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock) + @mock.patch(HOOK_CLASS, new=MagicMock) + def test_pod_priority_class_name(self, await_pod_completion_mock): + """ + Test ability to assign priorityClassName to pod + + todo: This isn't really a system test + """ priority_class_name = "medium-test" k = KubernetesPodOperator( @@ -1002,10 +1014,10 @@ def test_pod_name(self): do_xcom_push=False, ) - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion") - def test_on_kill(self, await_pod_completion_mock): - - client = kube_client.get_kube_client(in_cluster=False) + @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion", new=MagicMock) + def test_on_kill(self): + hook = KubernetesHook(conn_id=None, in_cluster=False) + client = hook.core_v1_client name = "test" namespace = "default" k = KubernetesPodOperator( @@ -1032,7 +1044,8 @@ def test_on_kill(self, await_pod_completion_mock): client.read_namespaced_pod(name=name, namespace=namespace) def test_reattach_failing_pod_once(self): - client = kube_client.get_kube_client(in_cluster=False) + hook = KubernetesHook(conn_id=None, in_cluster=False) + client = hook.core_v1_client name = "test" namespace = "default" @@ -1056,9 +1069,7 @@ def get_op(): context = create_context(k) # launch pod - with mock.patch( - "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion" - ) as await_pod_completion_mock: + with mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") as 
await_pod_completion_mock: pod_mock = MagicMock() pod_mock.status.phase = 'Succeeded' @@ -1082,9 +1093,7 @@ def get_op(): # `create_pod` should not be called because there's a pod there it should find # should use the found pod and patch as "already_checked" (in failure block) - with mock.patch( - "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod" - ) as create_mock: + with mock.patch(f"{POD_MANAGER_CLASS}.create_pod") as create_mock: with pytest.raises(AirflowException): k.execute(context) pod = client.read_namespaced_pod(name=name, namespace=namespace) @@ -1096,9 +1105,7 @@ def get_op(): # `create_pod` should be called because though there's still a pod to be found, # it will be `already_checked` - with mock.patch( - "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod" - ) as create_mock: + with mock.patch(f"{POD_MANAGER_CLASS}.create_pod") as create_mock: with pytest.raises(AirflowException): k.execute(context) create_mock.assert_called_once() diff --git a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py index 21a3bda7d331a..5a4efc73d4383 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py +++ b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py @@ -29,13 +29,13 @@ from kubernetes.client.rest import ApiException from airflow.exceptions import AirflowException -from airflow.kubernetes import kube_client from airflow.kubernetes.pod import Port from airflow.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv from airflow.kubernetes.secret import Secret from airflow.kubernetes.volume import Volume from airflow.kubernetes.volume_mount import VolumeMount from airflow.models import DAG, DagRun, TaskInstance +from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator from 
airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults @@ -45,6 +45,8 @@ # noinspection DuplicatedCode +HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook" + def create_context(task): dag = DAG(dag_id="dag") @@ -121,13 +123,14 @@ def setUp(self): } def tearDown(self): - client = kube_client.get_kube_client(in_cluster=False) + hook = KubernetesHook(conn_id=None, in_cluster=False) + client = hook.core_v1_client client.delete_collection_namespaced_pod(namespace="default") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion") - @mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def test_image_pull_secrets_correctly_set(self, mock_client, await_pod_completion_mock, create_mock): + @mock.patch(HOOK_CLASS, new=MagicMock) + def test_image_pull_secrets_correctly_set(self, await_pod_completion_mock, create_mock): fake_pull_secrets = "fakeSecret" k = KubernetesPodOperator( namespace='default', @@ -461,8 +464,8 @@ def test_xcom_push(self): @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion") - @mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start): + @mock.patch(HOOK_CLASS, new=MagicMock) + def test_envs_from_configmaps(self, mock_monitor, mock_start): # GIVEN configmap = 'test-configmap' # WHEN @@ -490,8 +493,8 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start): @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion") - 
@mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def test_envs_from_secrets(self, mock_client, await_pod_completion_mock, create_mock): + @mock.patch(HOOK_CLASS, new=MagicMock) + def test_envs_from_secrets(self, await_pod_completion_mock, create_mock): # GIVEN secret_ref = 'secret_name' secrets = [Secret('env', None, secret_ref)] diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py index 5194061e7a9c0..572f6e2890d25 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py @@ -21,7 +21,7 @@ import os import tempfile from unittest import mock -from unittest.mock import patch +from unittest.mock import MagicMock, patch import kubernetes import pytest @@ -34,6 +34,7 @@ from tests.test_utils.db import clear_db_connections KUBE_CONFIG_PATH = os.getenv('KUBECONFIG', '~/.kube/config') +HOOK_MODULE = "airflow.providers.cncf.kubernetes.hooks.kubernetes" class TestKubernetesHook: @@ -41,9 +42,9 @@ class TestKubernetesHook: def setup_class(cls) -> None: for conn_id, extra in [ ('in_cluster', {'extra__kubernetes__in_cluster': True}), + ('in_cluster_empty', {'extra__kubernetes__in_cluster': ''}), ('kube_config', {'extra__kubernetes__kube_config': '{"test": "kube"}'}), ('kube_config_path', {'extra__kubernetes__kube_config_path': 'path/to/file'}), - ('in_cluster_empty', {'extra__kubernetes__in_cluster': ''}), ('kube_config_empty', {'extra__kubernetes__kube_config': ''}), ('kube_config_path_empty', {'extra__kubernetes__kube_config_path': ''}), ('kube_config_empty', {'extra__kubernetes__kube_config': ''}), @@ -52,6 +53,10 @@ def setup_class(cls) -> None: ('context', {'extra__kubernetes__cluster_context': 'my-context'}), ('with_namespace', {'extra__kubernetes__namespace': 'mock_namespace'}), ('default_kube_config', {}), + ('disable_verify_ssl', {'extra__kubernetes__disable_verify_ssl': True}), + 
('disable_verify_ssl_empty', {'extra__kubernetes__disable_verify_ssl': ''}), + ('disable_tcp_keepalive', {'extra__kubernetes__disable_tcp_keepalive': True}), + ('disable_tcp_keepalive_empty', {'extra__kubernetes__disable_tcp_keepalive': ''}), ]: db.merge_conn(Connection(conn_type='kubernetes', conn_id=conn_id, extra=json.dumps(extra))) @@ -76,7 +81,7 @@ def teardown_class(cls) -> None: @patch("kubernetes.config.kube_config.KubeConfigLoader") @patch("kubernetes.config.kube_config.KubeConfigMerger") @patch("kubernetes.config.incluster_config.InClusterConfigLoader") - @patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook._get_default_client") + @patch(f"{HOOK_MODULE}.KubernetesHook._get_default_client") def test_in_cluster_connection( self, mock_get_default_client, @@ -131,6 +136,70 @@ def test_get_default_client( mock_loader.assert_not_called() assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) + @pytest.mark.parametrize( + 'disable_verify_ssl, conn_id, disable_called', + ( + (True, None, True), + (None, None, False), + (False, None, False), + (None, 'disable_verify_ssl', True), + (True, 'disable_verify_ssl', True), + (False, 'disable_verify_ssl', False), + (None, 'disable_verify_ssl_empty', False), + (True, 'disable_verify_ssl_empty', True), + (False, 'disable_verify_ssl_empty', False), + ), + ) + @patch("kubernetes.config.incluster_config.InClusterConfigLoader", new=MagicMock()) + @patch(f"{HOOK_MODULE}._disable_verify_ssl") + def test_disable_verify_ssl( + self, + mock_disable, + disable_verify_ssl, + conn_id, + disable_called, + ): + """ + Verifies whether disable verify ssl is called depending on combination of hook param and + connection extra. Hook param should beat extra. 
+ """ + kubernetes_hook = KubernetesHook(conn_id=conn_id, disable_verify_ssl=disable_verify_ssl) + api_conn = kubernetes_hook.get_conn() + assert mock_disable.called is disable_called + assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) + + @pytest.mark.parametrize( + 'disable_tcp_keepalive, conn_id, expected', + ( + (True, None, False), + (None, None, True), + (False, None, True), + (None, 'disable_tcp_keepalive', False), + (True, 'disable_tcp_keepalive', False), + (False, 'disable_tcp_keepalive', True), + (None, 'disable_tcp_keepalive_empty', True), + (True, 'disable_tcp_keepalive_empty', False), + (False, 'disable_tcp_keepalive_empty', True), + ), + ) + @patch("kubernetes.config.incluster_config.InClusterConfigLoader", new=MagicMock()) + @patch(f"{HOOK_MODULE}._enable_tcp_keepalive") + def test_disable_tcp_keepalive( + self, + mock_enable, + disable_tcp_keepalive, + conn_id, + expected, + ): + """ + Verifies whether enable tcp keepalive is called depending on combination of hook + param and connection extra. Hook param should beat extra. 
+ """ + kubernetes_hook = KubernetesHook(conn_id=conn_id, disable_tcp_keepalive=disable_tcp_keepalive) + api_conn = kubernetes_hook.get_conn() + assert mock_enable.called is expected + assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) + @pytest.mark.parametrize( 'config_path_param, conn_id, call_path', ( @@ -239,6 +308,61 @@ def test_client_types(self, mock_kube_config_merger, mock_kube_config_loader): assert isinstance(hook.api_client, kubernetes.client.ApiClient) assert isinstance(hook.get_conn(), kubernetes.client.ApiClient) + @patch(f"{HOOK_MODULE}._disable_verify_ssl") + @patch(f"{HOOK_MODULE}.KubernetesHook._get_default_client", new=MagicMock) + def test_patch_core_settings_verify_ssl(self, mock_disable_verify_ssl): + hook = KubernetesHook() + hook.get_conn() + mock_disable_verify_ssl.assert_not_called() + mock_disable_verify_ssl.reset_mock() + hook._deprecated_core_disable_verify_ssl = True + hook.get_conn() + mock_disable_verify_ssl.assert_called() + + @patch(f"{HOOK_MODULE}._enable_tcp_keepalive") + @patch(f"{HOOK_MODULE}.KubernetesHook._get_default_client", new=MagicMock) + def test_patch_core_settings_tcp_keepalive(self, mock_enable_tcp_keepalive): + hook = KubernetesHook() + hook.get_conn() + mock_enable_tcp_keepalive.assert_called() + mock_enable_tcp_keepalive.reset_mock() + hook._deprecated_core_disable_tcp_keepalive = True + hook.get_conn() + mock_enable_tcp_keepalive.assert_not_called() + + @patch("kubernetes.config.kube_config.KubeConfigLoader", new=MagicMock()) + @patch("kubernetes.config.kube_config.KubeConfigMerger", new=MagicMock()) + @patch("kubernetes.config.incluster_config.InClusterConfigLoader") + @patch(f"{HOOK_MODULE}.KubernetesHook._get_default_client") + def test_patch_core_settings_in_cluster(self, mock_get_default_client, mock_in_cluster_loader): + hook = KubernetesHook(conn_id=None) + hook.get_conn() + mock_in_cluster_loader.assert_not_called() + mock_in_cluster_loader.reset_mock() + 
hook._deprecated_core_in_cluster = False + hook.get_conn() + mock_in_cluster_loader.assert_not_called() + mock_get_default_client.assert_called() + + @pytest.mark.parametrize( + 'key, key_val, attr, attr_val', + [ + ('in_cluster', False, '_deprecated_core_in_cluster', False), + ('verify_ssl', False, '_deprecated_core_disable_verify_ssl', True), + ('cluster_context', 'hi', '_deprecated_core_cluster_context', 'hi'), + ('config_file', '/path/to/file.txt', '_deprecated_core_config_file', '/path/to/file.txt'), + ('enable_tcp_keepalive', False, '_deprecated_core_disable_tcp_keepalive', True), + ], + ) + @patch("kubernetes.config.incluster_config.InClusterConfigLoader", new=MagicMock()) + @patch("kubernetes.config.kube_config.KubeConfigLoader", new=MagicMock()) + @patch("kubernetes.config.kube_config.KubeConfigMerger", new=MagicMock()) + def test_core_settings_warnings(self, key, key_val, attr, attr_val): + hook = KubernetesHook(conn_id=None) + setattr(hook, attr, attr_val) + with pytest.warns(DeprecationWarning, match=rf'.*Airflow settings.*\n.*{key}={key_val!r}.*'): + hook.get_conn() + class TestKubernetesHookIncorrectConfiguration: @pytest.mark.parametrize( diff --git a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py index e70bf883267de..cac1b8a1b587a 100644 --- a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py @@ -29,9 +29,12 @@ from airflow.utils.session import create_session from airflow.utils.types import DagRunType from tests.test_utils import db +from tests.test_utils.config import conf_vars DEFAULT_DATE = timezone.datetime(2016, 1, 1, 1, 0, 0) KPO_MODULE = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod" +POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager" +HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook" 
@pytest.fixture(scope='function', autouse=True) @@ -66,28 +69,22 @@ def create_context(task, persist_to_db=False): } -POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager" - - class TestKubernetesPodOperator: @pytest.fixture(autouse=True) def setup(self, dag_maker): self.create_pod_patch = mock.patch(f"{POD_MANAGER_CLASS}.create_pod") self.await_pod_patch = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start") self.await_pod_completion_patch = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") - self.client_patch = mock.patch("airflow.kubernetes.kube_client.get_kube_client") + self.hook_patch = mock.patch(HOOK_CLASS) self.create_mock = self.create_pod_patch.start() self.await_start_mock = self.await_pod_patch.start() self.await_pod_mock = self.await_pod_completion_patch.start() - self.client_mock = self.client_patch.start() + self.hook_mock = self.hook_patch.start() self.dag_maker = dag_maker yield - self.create_pod_patch.stop() - self.await_pod_patch.stop() - self.await_pod_completion_patch.stop() - self.client_patch.stop() + mock.patch.stopall() def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> k8s.V1Pod: with self.dag_maker(dag_id='dag') as dag: @@ -127,9 +124,9 @@ def test_config_path(self): remote_pod_mock = MagicMock() remote_pod_mock.status.phase = 'Succeeded' self.await_pod_mock.return_value = remote_pod_mock - self.client_mock.list_namespaced_pod.return_value = [] self.run_pod(k) - self.client_mock.assert_called_once_with( + self.hook_mock.assert_called_once_with( + conn_id=None, in_cluster=False, cluster_context="default", config_file=file_path, @@ -226,8 +223,7 @@ def test_find_pod_labels(self): do_xcom_push=False, ) self.run_pod(k) - self.client_mock.return_value.list_namespaced_pod.assert_called_once() - _, kwargs = self.client_mock.return_value.list_namespaced_pod.call_args + _, kwargs = k.client.list_namespaced_pod.call_args assert kwargs['label_selector'] == ( 
'dag_id=dag,kubernetes_pod_operator=True,run_id=test,task_id=task,' 'already_checked!=True,!airflow-worker' @@ -576,7 +572,7 @@ def test_describes_pod_on_failure(self, await_container_mock, fetch_container_mo context = create_context(k) k.execute(context=context) - assert not self.client_mock.return_value.read_namespaced_pod.called + assert k.client.read_namespaced_pod.called is False @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_container_completion") @@ -794,8 +790,8 @@ def test_previous_pods_ignored_for_reattached(self): task_id="task", ) self.run_pod(k) - self.client_mock.return_value.list_namespaced_pod.assert_called_once() - _, kwargs = self.client_mock.return_value.list_namespaced_pod.call_args + k.client.list_namespaced_pod.assert_called_once() + _, kwargs = k.client.list_namespaced_pod.call_args assert 'already_checked!=True' in kwargs['label_selector'] @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.delete_pod") @@ -842,6 +838,29 @@ def test_mark_checked_if_not_deleted(self, mock_patch_already_checked, mock_dele mock_patch_already_checked.assert_called_once() mock_delete_pod.assert_not_called() + @pytest.mark.parametrize( + 'key, value, attr, patched_value', + [ + ('verify_ssl', 'False', '_deprecated_core_disable_verify_ssl', True), + ('in_cluster', 'False', '_deprecated_core_in_cluster', False), + ('cluster_context', 'hi', '_deprecated_core_cluster_context', 'hi'), + ('config_file', '/path/to/file.txt', '_deprecated_core_config_file', '/path/to/file.txt'), + ('enable_tcp_keepalive', 'False', '_deprecated_core_disable_tcp_keepalive', True), + ], + ) + def test_patch_core_settings(self, key, value, attr, patched_value): + # first verify the behavior for the default value + # the hook attr should be None + op = KubernetesPodOperator(task_id='abc', name='hi') + self.hook_patch.stop() + hook = 
op.get_hook() + assert getattr(hook, attr) is None + # now check behavior with a non-default value + with conf_vars({('kubernetes', key): value}): + op = KubernetesPodOperator(task_id='abc', name='hi') + hook = op.get_hook() + assert getattr(hook, attr) == patched_value + def test__suppress(): with mock.patch('logging.Logger.error') as mock_error: