From fd9e7f40b3ad2e589e907069c7a5c7af7dc652cb Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar Date: Sat, 15 Aug 2020 19:58:03 +0530 Subject: [PATCH 01/17] #9941 Add extra links for google dataproc --- .../google/cloud/operators/dataproc.py | 36 +++++++++- .../google/cloud/operators/test_dataproc.py | 69 ++++++++++++++++++- 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index c8ee5e35a69f1..710da8fd99659 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -35,12 +35,40 @@ from google.protobuf.field_mask_pb2 import FieldMask from airflow.exceptions import AirflowException -from airflow.models import BaseOperator +from airflow.models import BaseOperator, BaseOperatorLink +from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.utils import timezone from airflow.utils.decorators import apply_defaults +# pylint: disable=line-too-long +DATAPROC_JOB_LOG_LINK = "https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}" # noqa: E501 +DATAPROC_CLUSTER_JOBS_LINK = "https://console.cloud.google.com/dataproc/clusters/{cluster_name}/jobs?region={region}&project={project_id}" # noqa: E501 + + +class DataprocJobLink(BaseOperatorLink): + """ + Helper class for constructing Dataproc Job link + """ + name = "Dataproc Job" + + def get_link(self, operator, dttm): + ti = TaskInstance(task=operator, execution_date=dttm) + job_id = ti.xcom_pull(task_ids=operator.task_id, key='job_id') + if job_id: + return DATAPROC_JOB_LOG_LINK.format( + job_id=job_id, + region=operator.location, + project=operator.project_id + ) + + return DATAPROC_CLUSTER_JOBS_LINK.format( + cluster_name=operator.cluster_name, + region=operator.location, + project=operator.project_id + ) + # pylint: disable=too-many-instance-attributes class ClusterGenerator: @@ -1871,6 +1899,10 @@ class DataprocSubmitJobOperator(BaseOperator): template_fields = ('project_id', 'location', 'job', 'impersonation_chain', 'request_id') template_fields_renderers = {"job": "json"} + operator_extra_links = ( + DataprocJobLink(), + ) + @apply_defaults def __init__( self, @@ -1934,6 +1966,8 @@ def on_kill(self): if self.job_id and self.cancel_on_kill: self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id, location=self.location) + context['task_instance'].xcom_push(key='job_id', value=job_id) + class DataprocUpdateClusterOperator(BaseOperator): """ diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index fb2ceef1f98e7..74c2b8bd04039 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -25,6 +25,7 @@ from google.api_core.retry import Retry from airflow import AirflowException +from airflow.models import DAG, DagBag, TaskInstance, XCom from airflow.providers.google.cloud.operators.dataproc import ( ClusterGenerator, DataprocCreateClusterOperator, @@ -32,6 +33,7 @@ DataprocDeleteClusterOperator, DataprocInstantiateInlineWorkflowTemplateOperator, DataprocInstantiateWorkflowTemplateOperator, + DataprocJobLink, DataprocScaleClusterOperator, DataprocSubmitHadoopJobOperator, DataprocSubmitHiveJobOperator, @@ -40,8 +42,10 @@ DataprocSubmitPySparkJobOperator, DataprocSubmitSparkJobOperator, DataprocSubmitSparkSqlJobOperator, - DataprocUpdateClusterOperator, + DataprocUpdateClusterOperator ) +from airflow.settings import Session +from airflow.utils.session import provide_session from airflow.version import version as airflow_version cluster_params = inspect.signature(ClusterGenerator.__init__).parameters @@ -171,6 +175,8 @@ }, "jobs": [{"step_id": "pig_job_1", "pig_job": {}}], } +TEST_DAG_ID = 'test-dataproc-operators' +DEFAULT_DATE = datetime(2020, 1, 1) def assert_warning(msg: str, warnings): @@ -550,6 +556,23 @@ def test_execute(self, mock_hook): class TestDataprocSubmitJobOperator(unittest.TestCase): + + def setUp(self): + self.dagbag = DagBag( + dag_folder='/dev/null', include_examples=False + ) + self.dag = DAG(TEST_DAG_ID, default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + }) + + def tearDown(self): + session = Session() + session.query(TaskInstance).filter_by( + dag_id=TEST_DAG_ID).delete() + session.commit() + session.close() + @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): job = {} @@ -653,6 +676,50 @@ def test_on_kill(self, mock_hook): project_id=GCP_PROJECT, location=GCP_LOCATION, job_id=job_id ) + @provide_session + def test_operator_extra_links(self, session): + job = {} + job_id = 'test_job_id_12345' + execution_date = datetime(2020, 7, 20) + + op = DataprocSubmitJobOperator( + task_id=TASK_ID, + location=GCP_LOCATION, + project_id=GCP_PROJECT, + job=job, + gcp_conn_id=GCP_CONN_ID, + cluster_name=CLUSTER_NAME + ) + self.dag.clear() + session.query(XCom).deete() + + ti = TaskInstance( + task=op, + execution_date=execution_date + ) + + self.assertEqual( + # pylint: disable=line-too-long + 'https://console.cloud.google.com/dataproc/clusters/{cluster_name}/job?region={region}&project={project_id}'.format( # noqa: E501 + cluster_name=CLUSTER_NAME, + region=GCP_LOCATION, + project_id=GCP_PROJECT + ), + op.get_extra_links(execution_date, DataprocJobLink.name) + ) + + ti.xcom_push(key='job_id', value=job_id) + + self.assertEqual( + # pylint: disable=line-too-long + 'https://console.cloud.google.com/dataproc/jobs/{job_id}/?region={region}&project={project_id}'.format( # noqa: E501 + job_id=job_id, + region=GCP_LOCATION, + project_id=GCP_PROJECT + ), + op.get_extra_links(execution_date, DataprocJobLink.name) + ) + class TestDataprocUpdateClusterOperator(unittest.TestCase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) From 2305078970a1207f93401e82b51f6eb0cdcee849 Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar Date: Sat, 15 Aug 2020 23:25:18 +0530 Subject: [PATCH 02/17] #9941 Update DataprocJobLink for DataprocSubmitJobOperator and DataprocJobBaseOperator --- .../google/cloud/operators/dataproc.py | 30 +++++++++----- .../google/cloud/operators/test_dataproc.py | 41 ++++++++++--------- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 710da8fd99659..eb0e0ce9527af 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -44,7 +44,6 @@ # pylint: disable=line-too-long DATAPROC_JOB_LOG_LINK = "https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}" # noqa: E501 -DATAPROC_CLUSTER_JOBS_LINK = "https://console.cloud.google.com/dataproc/clusters/{cluster_name}/jobs?region={region}&project={project_id}" # noqa: E501 class DataprocJobLink(BaseOperatorLink): @@ -55,19 +54,19 @@ class DataprocJobLink(BaseOperatorLink): def get_link(self, operator, dttm): ti = TaskInstance(task=operator, execution_date=dttm) - job_id = ti.xcom_pull(task_ids=operator.task_id, key='job_id') - if job_id: + if isinstance(operator, DataprocJobBaseOperator): + job_conf = ti.xcom_pull(task_ids=operator.task_id, key='job_conf') return DATAPROC_JOB_LOG_LINK.format( - job_id=job_id, - region=operator.location, - project=operator.project_id - ) + job_id=job_conf['job_id'], + region=operator.region, + project_id=job_conf['project_id']) if job_conf else '' - return DATAPROC_CLUSTER_JOBS_LINK.format( - cluster_name=operator.cluster_name, + job_id = ti.xcom_pull(task_ids=operator.task_id, key='job_id') + return DATAPROC_JOB_LOG_LINK.format( + job_id=job_id, region=operator.location, - project=operator.project_id - ) + project_id=operator.project_id + ) if job_id else '' # pylint: disable=too-many-instance-attributes @@ -959,6 +958,10 @@ class DataprocJobBaseOperator(BaseOperator): job_type = "" + operator_extra_links = ( + DataprocJobLink(), + ) + @apply_defaults def __init__( self, @@ -1033,12 +1036,17 @@ def execute(self, context): ) job_id = job_object.reference.job_id self.log.info('Job %s submitted successfully.', job_id) + context['task_instance'].xcom_push(key='job_conf', value={ + 'job_id': job_id, + 'project_id': self.project_id + }) if not self.asynchronous: self.log.info('Waiting for job %s to complete', job_id) self.hook.wait_for_job(job_id=job_id, location=self.region, project_id=self.project_id) self.log.info('Job %s completed successfully.', job_id) return job_id + else: raise AirflowException("Create a job template before") diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 74c2b8bd04039..4aca8df1c386c 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -19,6 +19,7 @@ import unittest from datetime import datetime from unittest import mock +from unittest.mock import MagicMock import pytest from google.api_core.exceptions import AlreadyExists, NotFound @@ -565,6 +566,7 @@ def setUp(self): 'owner': 'airflow', 'start_date': DEFAULT_DATE }) + self.mock_context = MagicMock() def tearDown(self): session = Session() @@ -592,7 +594,7 @@ def test_execute(self, mock_hook): request_id=REQUEST_ID, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) + op.execute(context=self.mock_context) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.submit_job.assert_called_once_with( @@ -687,11 +689,10 @@ def test_operator_extra_links(self, session): location=GCP_LOCATION, project_id=GCP_PROJECT, job=job, - gcp_conn_id=GCP_CONN_ID, - cluster_name=CLUSTER_NAME + gcp_conn_id=GCP_CONN_ID ) self.dag.clear() - session.query(XCom).deete() + session.query(XCom).delete() ti = TaskInstance( task=op, @@ -699,25 +700,19 @@ def test_operator_extra_links(self, session): ) self.assertEqual( - # pylint: disable=line-too-long - 'https://console.cloud.google.com/dataproc/clusters/{cluster_name}/job?region={region}&project={project_id}'.format( # noqa: E501 - cluster_name=CLUSTER_NAME, - region=GCP_LOCATION, - project_id=GCP_PROJECT - ), - op.get_extra_links(execution_date, DataprocJobLink.name) + op.get_extra_links(execution_date, DataprocJobLink.name), '' ) ti.xcom_push(key='job_id', value=job_id) self.assertEqual( + op.get_extra_links(execution_date, DataprocJobLink.name), # pylint: disable=line-too-long - 'https://console.cloud.google.com/dataproc/jobs/{job_id}/?region={region}&project={project_id}'.format( # noqa: E501 + 'https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}'.format( # noqa: E501 job_id=job_id, region=GCP_LOCATION, project_id=GCP_PROJECT - ), - op.get_extra_links(execution_date, DataprocJobLink.name) + ) ) @@ -854,8 +849,10 @@ def test_execute(self, mock_hook, mock_uuid): variables=self.variables, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + op.execute(context=MagicMock()) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION ) @@ -913,8 +910,10 @@ def test_execute(self, mock_hook, mock_uuid): variables=self.variables, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + op.execute(context=MagicMock()) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION ) @@ -978,8 +977,10 @@ def test_execute(self, mock_hook, mock_uuid): variables=self.variables, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + op.execute(context=MagicMock()) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION ) From a4cbad08308266d2bf438d35c3bdb515a21e62de Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar Date: Sun, 16 Aug 2020 02:02:27 +0530 Subject: [PATCH 03/17] #9941 Add test for DataprocSparkOperator --- .../google/cloud/operators/test_dataproc.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 4aca8df1c386c..31cad5c69627a 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -1042,6 +1042,23 @@ class TestDataProcSparkOperator(unittest.TestCase): "spark_job": {"jar_file_uris": jars, "main_class": main_class}, } + def setUp(self): + self.dagbag = DagBag( + dag_folder='/dev/null', include_examples=False + ) + self.dag = DAG(TEST_DAG_ID, default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + }) + self.mock_context = MagicMock() + + def tearDown(self): + session = Session() + session.query(TaskInstance).filter_by( + dag_id=TEST_DAG_ID).delete() + session.commit() + session.close() + @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_deprecation_warning(self, mock_hook): with pytest.warns(DeprecationWarning) as warnings: @@ -1067,6 +1084,46 @@ def test_execute(self, mock_hook, mock_uuid): job = op.generate_job() assert self.job == job + @provide_session + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_operator_extra_links(self, mock_hook, session): + job_id = 'test_spark_job_12345' + execution_date = datetime(2020, 7, 20) + + op = DataprocSubmitSparkJobOperator( + task_id=TASK_ID, + region=GCP_LOCATION, + gcp_conn_id=GCP_CONN_ID, + main_class=self.main_class, + dataproc_jars=self.jars + ) + self.dag.clear() + session.query(XCom).delete() + + ti = TaskInstance( + task=op, + execution_date=execution_date + ) + + self.assertEqual( + op.get_extra_links(execution_date, DataprocJobLink.name), '' + ) + + ti.xcom_push(key='job_conf', value={ + 'job_id': job_id, + 'project_id': GCP_PROJECT + }) + + self.assertEqual( + op.get_extra_links(execution_date, DataprocJobLink.name), + # pylint: disable=line-too-long + "https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}".format( # noqa: E501 + job_id=job_id, + region=GCP_LOCATION, + project_id=GCP_PROJECT + ) + ) + class TestDataProcHadoopOperator(unittest.TestCase): args = ["wordcount", "gs://pub/shakespeare/rose.txt"] From baa2d2634ddf4885072eb53f9be895c571d2d21d Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar Date: Sun, 16 Aug 2020 15:13:09 +0530 Subject: [PATCH 04/17] #9941 Add DataprocJobLink to serializer --- airflow/serialization/serialized_objects.py | 7 ++ .../google/cloud/operators/test_dataproc.py | 88 +++++++++++++++---- 2 files changed, 79 insertions(+), 16 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 8a6fdc89d2062..50af812682fb0 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -72,6 +72,13 @@ "airflow.sensors.external_task_sensor.ExternalTaskSensorLink", } +BUILTIN_OPERATOR_EXTRA_LINKS: List[str] = [ + "airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink", + "airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink", + "airflow.providers.google.cloud.operators.dataproc.DataprocJobLink", + "airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink", + "airflow.providers.qubole.operators.qubole.QDSLink" +] @cache def get_operator_extra_links(): diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 31cad5c69627a..7dee8e9ff9910 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -45,6 +45,7 @@ DataprocSubmitSparkSqlJobOperator, DataprocUpdateClusterOperator ) +from airflow.serialization.serialized_objects import SerializedDAG from airflow.settings import Session from airflow.utils.session import provide_session from airflow.version import version as airflow_version @@ -683,36 +684,63 @@ def test_operator_extra_links(self, session): job = {} job_id = 'test_job_id_12345' execution_date = datetime(2020, 7, 20) + # pylint: disable=line-too-long + expected_extra_link = 'https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}'.format( # noqa: E501 + job_id=job_id, + region=GCP_LOCATION, + project_id=GCP_PROJECT + ) op = DataprocSubmitJobOperator( task_id=TASK_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, job=job, - gcp_conn_id=GCP_CONN_ID + gcp_conn_id=GCP_CONN_ID, + dag=self.dag ) self.dag.clear() session.query(XCom).delete() + serialized_dag = SerializedDAG.to_dict(self.dag) + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] + + # Assert operator links for serialized_dag + self.assertEqual( + serialized_dag['dag']['tasks'][0]['_operator_extra_links'], + [{'airflow.providers.google.cloud.operators.dataproc.DataprocJobLink': {}}] + ) + + # Assert operator link types are preserved during deserialization + self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocJobLink) + ti = TaskInstance( task=op, execution_date=execution_date ) + # Assert operator link is empty when no XCom push occured self.assertEqual( - op.get_extra_links(execution_date, DataprocJobLink.name), '' + op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' + ) + + # Assert operator link is empty for deserialized task when no XCom push occured + self.assertEqual( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' ) ti.xcom_push(key='job_id', value=job_id) + # Assert operator links are preserved in deserialized tasks + self.assertEqual( + deserialized_task.get_extra_links(execution_date, DataprocJobLink.name), + expected_extra_link + ) + # Assert operator links after execution self.assertEqual( op.get_extra_links(execution_date, DataprocJobLink.name), - # pylint: disable=line-too-long - 'https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}'.format( # noqa: E501 - job_id=job_id, - region=GCP_LOCATION, - project_id=GCP_PROJECT - ) + expected_extra_link ) @@ -1089,24 +1117,50 @@ def test_execute(self, mock_hook, mock_uuid): def test_operator_extra_links(self, mock_hook, session): job_id = 'test_spark_job_12345' execution_date = datetime(2020, 7, 20) + # pylint: disable=line-too-long + expected_extra_link = "https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}".format( # noqa: E501 + job_id=job_id, + region=GCP_LOCATION, + project_id=GCP_PROJECT + ) op = DataprocSubmitSparkJobOperator( task_id=TASK_ID, region=GCP_LOCATION, gcp_conn_id=GCP_CONN_ID, main_class=self.main_class, - dataproc_jars=self.jars + dataproc_jars=self.jars, + dag=self.dag ) self.dag.clear() session.query(XCom).delete() + serialized_dag = SerializedDAG.to_dict(self.dag) + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] + + # Assert operator links for serialized DAG + self.assertEqual( + serialized_dag['dag']['tasks'][0]['_operator_extra_links'], + [{'airflow.providers.google.cloud.operators.dataproc.DataprocJobLink': {}}] + ) + + # Assert operator link types are preserved during deserialization + self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocJobLink) + ti = TaskInstance( task=op, execution_date=execution_date ) + # Assert operator link is empty when no XCom push occured self.assertEqual( - op.get_extra_links(execution_date, DataprocJobLink.name), '' + op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' + ) + + # Assert operator link is empty for deserialized task when no XCom push occured + self.assertEqual( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' ) ti.xcom_push(key='job_conf', value={ @@ -1114,14 +1168,16 @@ def test_operator_extra_links(self, mock_hook, session): 'project_id': GCP_PROJECT }) + # Assert operator links are preserved in deserialized tasks + self.assertEqual( + deserialized_task.get_extra_links(execution_date, DataprocJobLink.name), + expected_extra_link + ) + + # Assert operator links after task execution self.assertEqual( op.get_extra_links(execution_date, DataprocJobLink.name), - # pylint: disable=line-too-long - "https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}".format( # noqa: E501 - job_id=job_id, - region=GCP_LOCATION, - project_id=GCP_PROJECT - ) + expected_extra_link ) From f20a10f0a5760947a890d6870d0bf71fa817f1e7 Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar Date: Mon, 17 Aug 2020 21:04:38 +0530 Subject: [PATCH 05/17] #9941 Fix tests for DataprocJobBaseOperators --- .../google/cloud/operators/dataproc.py | 37 ++--- .../google/cloud/operators/test_dataproc.py | 136 +++++++++--------- 2 files changed, 87 insertions(+), 86 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index eb0e0ce9527af..3fa42c55b5fac 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -42,8 +42,8 @@ from airflow.utils import timezone from airflow.utils.decorators import apply_defaults -# pylint: disable=line-too-long -DATAPROC_JOB_LOG_LINK = "https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}" # noqa: E501 +DATAPROC_JOB_LOG_LINK = \ + "https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}" class DataprocJobLink(BaseOperatorLink): @@ -54,19 +54,12 @@ class DataprocJobLink(BaseOperatorLink): def get_link(self, operator, dttm): ti = TaskInstance(task=operator, execution_date=dttm) - if isinstance(operator, DataprocJobBaseOperator): - job_conf = ti.xcom_pull(task_ids=operator.task_id, key='job_conf') - return DATAPROC_JOB_LOG_LINK.format( - job_id=job_conf['job_id'], - region=operator.region, - project_id=job_conf['project_id']) if job_conf else '' - - job_id = ti.xcom_pull(task_ids=operator.task_id, key='job_id') + job_conf = ti.xcom_pull(task_ids=operator.task_id, key='job_conf') return DATAPROC_JOB_LOG_LINK.format( - job_id=job_id, - region=operator.location, - project_id=operator.project_id - ) if job_id else '' + job_id=job_conf['job_id'], + region=job_conf['region'], + project_id=job_conf['project_id'] + ) if job_conf else '' # pylint: disable=too-many-instance-attributes @@ -1036,8 +1029,9 @@ def execute(self, context): ) job_id = job_object.reference.job_id self.log.info('Job %s submitted successfully.', job_id) - context['task_instance'].xcom_push(key='job_conf', value={ + self.xcom_push(context, key='job_conf', value={ 'job_id': job_id, + 'region': self.region, 'project_id': self.project_id }) @@ -1050,7 +1044,7 @@ def execute(self, context): else: raise AirflowException("Create a job template before") - def on_kill(self) -> None: + def on_kill(self): """ Callback called when the operator is killed. Cancel any running job. @@ -1118,6 +1112,10 @@ class DataprocSubmitPigJobOperator(DataprocJobBaseOperator): ui_color = '#0273d4' job_type = 'pig_job' + operator_extra_links = ( + DataprocJobLink(), + ) + @apply_defaults def __init__( self, @@ -1165,6 +1163,7 @@ def execute(self, context): self.job_template.add_variables(self.variables) super().execute(context) + # context['task_instance'].xcom_push(key='job_id', value=self.job['reference']['job_id']) class DataprocSubmitHiveJobOperator(DataprocJobBaseOperator): @@ -1974,7 +1973,11 @@ def on_kill(self): if self.job_id and self.cancel_on_kill: self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id, location=self.location) - context['task_instance'].xcom_push(key='job_id', value=job_id) + self.xcom_push(context, key='job_conf', value={ + 'job_id': job_id, + 'region': self.location, + 'project_id': self.project_id + }) class DataprocUpdateClusterOperator(BaseOperator): diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 7dee8e9ff9910..5cfcf28db7610 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -26,7 +26,7 @@ from google.api_core.retry import Retry from airflow import AirflowException -from airflow.models import DAG, DagBag, TaskInstance, XCom +from airflow.models import DAG, DagBag, TaskInstance from airflow.providers.google.cloud.operators.dataproc import ( ClusterGenerator, DataprocCreateClusterOperator, @@ -46,9 +46,8 @@ DataprocUpdateClusterOperator ) from airflow.serialization.serialized_objects import SerializedDAG -from airflow.settings import Session -from airflow.utils.session import provide_session from airflow.version import version as airflow_version +from tests.test_utils.db import clear_db_runs, clear_db_xcom cluster_params = inspect.signature(ClusterGenerator.__init__).parameters @@ -179,6 +178,11 @@ } TEST_DAG_ID = 'test-dataproc-operators' DEFAULT_DATE = datetime(2020, 1, 1) +TEST_JOB_ID = 'test-job' + +DATAPROC_JOB_LINK_EXPECTED = \ + f'https://console.cloud.google.com/dataproc/jobs/{TEST_JOB_ID}?' \ + f'region={GCP_LOCATION}&project={GCP_PROJECT}' def assert_warning(msg: str, warnings): @@ -559,29 +563,27 @@ def test_execute(self, mock_hook): class TestDataprocSubmitJobOperator(unittest.TestCase): - def setUp(self): - self.dagbag = DagBag( + @classmethod + def setUpClass(cls): + cls.dagbag = DagBag( dag_folder='/dev/null', include_examples=False ) - self.dag = DAG(TEST_DAG_ID, default_args={ + cls.dag = DAG(TEST_DAG_ID, default_args={ 'owner': 'airflow', 'start_date': DEFAULT_DATE }) - self.mock_context = MagicMock() + cls.mock_context = MagicMock() - def tearDown(self): - session = Session() - session.query(TaskInstance).filter_by( - dag_id=TEST_DAG_ID).delete() - session.commit() - session.close() + @classmethod + def tearDownClass(cls): + clear_db_runs() + clear_db_xcom() @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): job = {} - job_id = "job_id" mock_hook.return_value.wait_for_job.return_value = None - mock_hook.return_value.submit_job.return_value.reference.job_id = job_id + mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID op = DataprocSubmitJobOperator( task_id=TASK_ID, @@ -679,28 +681,17 @@ def test_on_kill(self, mock_hook): project_id=GCP_PROJECT, location=GCP_LOCATION, job_id=job_id ) - @provide_session - def test_operator_extra_links(self, session): - job = {} - job_id = 'test_job_id_12345' - execution_date = datetime(2020, 7, 20) - # pylint: disable=line-too-long - expected_extra_link = 'https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}'.format( # noqa: E501 - job_id=job_id, - region=GCP_LOCATION, - project_id=GCP_PROJECT - ) - + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_operator_extra_links(self, mock_hook): + mock_hook.return_value.project_id = GCP_PROJECT op = DataprocSubmitJobOperator( task_id=TASK_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, - job=job, + job={}, gcp_conn_id=GCP_CONN_ID, dag=self.dag ) - self.dag.clear() - session.query(XCom).delete() serialized_dag = SerializedDAG.to_dict(self.dag) deserialized_dag = SerializedDAG.from_dict(serialized_dag) @@ -717,7 +708,7 @@ def test_operator_extra_links(self, session): ti = TaskInstance( task=op, - execution_date=execution_date + execution_date=DEFAULT_DATE ) # Assert operator link is empty when no XCom push occured @@ -730,17 +721,26 @@ def test_operator_extra_links(self, session): deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' ) - ti.xcom_push(key='job_id', value=job_id) + ti.xcom_push(key='job_conf', value={ + 'job_id': TEST_JOB_ID, + 'region': GCP_LOCATION, + 'project_id': GCP_PROJECT + }) # Assert operator links are preserved in deserialized tasks self.assertEqual( - deserialized_task.get_extra_links(execution_date, DataprocJobLink.name), - expected_extra_link + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), + DATAPROC_JOB_LINK_EXPECTED ) # Assert operator links after execution self.assertEqual( - op.get_extra_links(execution_date, DataprocJobLink.name), - expected_extra_link + op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), + DATAPROC_JOB_LINK_EXPECTED + ) + # Check for negative case + self.assertEqual( + op.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), + '' ) @@ -1062,30 +1062,31 @@ def test_builder(self, mock_hook, mock_uuid): class TestDataProcSparkOperator(unittest.TestCase): main_class = "org.apache.spark.examples.SparkPi" jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"] - job_id = "uuid_id" job = { - "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id}, + "reference": { + "project_id": GCP_PROJECT, + "job_id": "{{task.task_id}}_{{ds_nodash}}_" + TEST_JOB_ID, + }, "placement": {"cluster_name": "cluster-1"}, "labels": {"airflow-version": AIRFLOW_VERSION}, "spark_job": {"jar_file_uris": jars, "main_class": main_class}, } - def setUp(self): - self.dagbag = DagBag( + @classmethod + def setUpClass(cls): + cls.dagbag = DagBag( dag_folder='/dev/null', include_examples=False ) - self.dag = DAG(TEST_DAG_ID, default_args={ + cls.dag = DAG(TEST_DAG_ID, default_args={ 'owner': 'airflow', 'start_date': DEFAULT_DATE }) - self.mock_context = MagicMock() + cls.mock_context = MagicMock() - def tearDown(self): - session = Session() - session.query(TaskInstance).filter_by( - dag_id=TEST_DAG_ID).delete() - session.commit() - session.close() + @classmethod + def tearDown(cls): + clear_db_runs() + clear_db_xcom() @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_deprecation_warning(self, mock_hook): @@ -1098,9 +1099,9 @@ def test_deprecation_warning(self, mock_hook): @mock.patch(DATAPROC_PATH.format("uuid.uuid4")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook, mock_uuid): - mock_uuid.return_value = self.job_id + mock_uuid.return_value = TEST_JOB_ID mock_hook.return_value.project_id = GCP_PROJECT - mock_uuid.return_value = self.job_id + mock_uuid.return_value = TEST_JOB_ID op = DataprocSubmitSparkJobOperator( task_id=TASK_ID, @@ -1112,17 +1113,9 @@ def test_execute(self, mock_hook, mock_uuid): job = op.generate_job() assert self.job == job - @provide_session @mock.patch(DATAPROC_PATH.format("DataprocHook")) - def test_operator_extra_links(self, mock_hook, session): - job_id = 'test_spark_job_12345' - execution_date = datetime(2020, 7, 20) - # pylint: disable=line-too-long - expected_extra_link = "https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}".format( # noqa: E501 - job_id=job_id, - region=GCP_LOCATION, - project_id=GCP_PROJECT - ) + def test_operator_extra_links(self, mock_hook): + mock_hook.return_value.project_id = GCP_PROJECT op = DataprocSubmitSparkJobOperator( task_id=TASK_ID, @@ -1132,8 +1125,6 @@ def test_operator_extra_links(self, mock_hook, session): dataproc_jars=self.jars, dag=self.dag ) - self.dag.clear() - session.query(XCom).delete() serialized_dag = SerializedDAG.to_dict(self.dag) deserialized_dag = SerializedDAG.from_dict(serialized_dag) @@ -1150,7 +1141,7 @@ def test_operator_extra_links(self, mock_hook, session): ti = TaskInstance( task=op, - execution_date=execution_date + execution_date=DEFAULT_DATE ) # Assert operator link is empty when no XCom push occured @@ -1164,20 +1155,27 @@ def test_operator_extra_links(self, mock_hook, session): ) ti.xcom_push(key='job_conf', value={ - 'job_id': job_id, + 'job_id': TEST_JOB_ID, + 'region': GCP_LOCATION, 'project_id': GCP_PROJECT }) + # Assert operator links after task execution + self.assertEqual( + op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), + DATAPROC_JOB_LINK_EXPECTED + ) + # Assert operator links are preserved in deserialized tasks self.assertEqual( - deserialized_task.get_extra_links(execution_date, DataprocJobLink.name), - expected_extra_link + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), + DATAPROC_JOB_LINK_EXPECTED ) - # Assert operator links after task execution + # Assert for negative case self.assertEqual( - op.get_extra_links(execution_date, DataprocJobLink.name), - expected_extra_link + deserialized_task.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), + '' ) From 3e44589be1bb4dbc22ce4bee5786281423cd567a Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar Date: Tue, 18 Aug 2020 04:07:24 +0530 Subject: [PATCH 06/17] apache#9941 Add links for Cluster configuration operators --- .../google/cloud/operators/dataproc.py | 48 ++- airflow/serialization/serialized_objects.py | 1 + .../google/cloud/operators/test_dataproc.py | 276 +++++++++++++++++- 3 files changed, 314 insertions(+), 11 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 3fa42c55b5fac..18dffcae79cc3 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -42,8 +42,12 @@ from airflow.utils import timezone from airflow.utils.decorators import apply_defaults +DATAPROC_BASE_LINK = "https://console.cloud.google.com/dataproc" DATAPROC_JOB_LOG_LINK = \ - "https://console.cloud.google.com/dataproc/jobs/{job_id}?region={region}&project={project_id}" + DATAPROC_BASE_LINK + "/jobs/{job_id}?region={region}&project={project_id}" +DATAPROC_CLUSTER_LINK = \ + DATAPROC_BASE_LINK + "/clusters/{cluster_name}/monitoring?" \ + "region={region}&project={project_id}" class DataprocJobLink(BaseOperatorLink): @@ -62,6 +66,22 @@ def get_link(self, operator, dttm): ) if job_conf else '' +class DataprocClusterLink(BaseOperatorLink): + """ + Helper class for constructing Dataproc Cluster link + """ + name = "Dataproc Cluster" + + def get_link(self, operator, dttm): + ti = TaskInstance(task=operator, execution_date=dttm) + cluster_conf = ti.xcom_pull(task_ids=operator.task_id, key='cluster_conf') + return DATAPROC_CLUSTER_LINK.format( + cluster_name=cluster_conf['cluster_name'], + region=cluster_conf['region'], + project_id=cluster_conf['project_id'] + ) if cluster_conf else '' + + # pylint: disable=too-many-instance-attributes class ClusterGenerator: """ @@ -498,6 +518,10 @@ class DataprocCreateClusterOperator(BaseOperator): ) template_fields_renderers = {'cluster_config': 'json'} + operator_extra_links = ( + DataprocClusterLink(), + ) + @apply_defaults def __init__( # pylint: disable=too-many-arguments self, @@ -662,6 +686,11 @@ def execute(self, context) -> dict: cluster = self._create_cluster(hook) self._handle_error_state(hook, cluster) + self.xcom_push(context, key='cluster_conf', value={ + 'cluster_name': self.cluster_name, + 'region': self.region, + 'project_id': self.project_id + }) return Cluster.to_dict(cluster) @@ -714,6 +743,10 @@ class DataprocScaleClusterOperator(BaseOperator): template_fields = ['cluster_name', 'project_id', 'region', 'impersonation_chain'] + operator_extra_links = ( + DataprocClusterLink(), + ) + @apply_defaults def __init__( self, @@ -803,6 +836,11 @@ def execute(self, context) -> None: ) operation.result() self.log.info("Cluster scaling finished") + self.xcom_push(context, key='cluster_conf', value={ + 'cluster_name': self.cluster_name, + 'region': self.region, + 'project_id': self.project_id + }) class DataprocDeleteClusterOperator(BaseOperator): @@ -2033,6 +2071,9 @@ class DataprocUpdateClusterOperator(BaseOperator): """ template_fields = ('impersonation_chain', 'cluster_name') + operator_extra_links = ( + DataprocClusterLink(), + ) @apply_defaults def __init__( # pylint: disable=too-many-arguments @@ -2083,3 +2124,8 @@ def execute(self, context: Dict): ) operation.result() self.log.info("Updated %s cluster.", self.cluster_name) + self.xcom_push(context, key='cluster_conf', value={ + 'cluster_name': self.cluster_name, + 'region': self.location, + 'project_id': self.project_id + }) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 50af812682fb0..c5ed0cad83216 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -76,6 +76,7 @@ "airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink", "airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink", "airflow.providers.google.cloud.operators.dataproc.DataprocJobLink", + "airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink", "airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink", "airflow.providers.qubole.operators.qubole.QDSLink" ] diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 5cfcf28db7610..e39f0e4f1d08f 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -29,6 +29,7 @@ from airflow.models import DAG, DagBag, TaskInstance from airflow.providers.google.cloud.operators.dataproc import ( ClusterGenerator, + DataprocClusterLink, DataprocCreateClusterOperator, DataprocCreateWorkflowTemplateOperator, DataprocDeleteClusterOperator, @@ -183,6 +184,9 @@ DATAPROC_JOB_LINK_EXPECTED = \ f'https://console.cloud.google.com/dataproc/jobs/{TEST_JOB_ID}?' \ f'region={GCP_LOCATION}&project={GCP_PROJECT}' +DATAPROC_CLUSTER_LINK_EXPECTED = \ + f'https://console.cloud.google.com/dataproc/clusters/{CLUSTER_NAME}/monitoring?' \ + f'region={GCP_LOCATION}&project={GCP_PROJECT}' def assert_warning(msg: str, warnings): @@ -303,6 +307,23 @@ def test_build_with_custom_image_family(self): class TestDataprocClusterCreateOperator(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.dagbag = DagBag( + dag_folder='/dev/null', include_examples=False + ) + cls.dag = DAG(TEST_DAG_ID, default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + }) + cls.mock_context = MagicMock() + + @classmethod + def tearDownClass(cls): + clear_db_runs() + clear_db_xcom() + def test_deprecation_warning(self): with pytest.warns(DeprecationWarning) as warnings: op = DataprocCreateClusterOperator( @@ -347,8 +368,10 @@ def test_execute(self, mock_hook, to_dict_mock): metadata=METADATA, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + op.execute(context=self.mock_context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + ) mock_hook.return_value.create_cluster.assert_called_once_with( region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -381,8 +404,10 @@ def test_execute_if_cluster_exists(self, mock_hook, to_dict_mock): request_id=REQUEST_ID, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + op.execute(context=self.mock_context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + ) mock_hook.return_value.create_cluster.assert_called_once_with( region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -423,7 +448,7 @@ def test_execute_if_cluster_exists_do_not_use(self, mock_hook): use_if_exists=False, ) with pytest.raises(AlreadyExists): - op.execute(context={}) + op.execute(context=self.mock_context) @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_if_cluster_exists_in_error_state(self, mock_hook): @@ -447,7 +472,7 @@ def test_execute_if_cluster_exists_in_error_state(self, mock_hook): request_id=REQUEST_ID, ) with pytest.raises(AirflowException): - op.execute(context={}) + op.execute(context=self.mock_context) mock_hook.return_value.diagnose_cluster.assert_called_once_with( region=GCP_LOCATION, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME @@ -486,7 +511,7 @@ def test_execute_if_cluster_exists_in_deleting_state( gcp_conn_id=GCP_CONN_ID, ) with pytest.raises(AirflowException): - op.execute(context={}) + op.execute(context=self.mock_context) calls = [mock.call(mock_hook.return_value), mock.call(mock_hook.return_value)] mock_get_cluster.assert_has_calls(calls) @@ -495,8 +520,90 @@ def test_execute_if_cluster_exists_in_deleting_state( region=GCP_LOCATION, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME ) + def test_operator_extra_links(self): + op = DataprocCreateClusterOperator( + task_id=TASK_ID, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + cluster=CLUSTER, + delete_on_error=True, + gcp_conn_id=GCP_CONN_ID, + dag=self.dag + ) + + serialized_dag = SerializedDAG.to_dict(self.dag) + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] + + # Assert operator links for serialized DAG + self.assertEqual( + serialized_dag['dag']['tasks'][0]['_operator_extra_links'], + [{'airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink': {}}] + ) + + # Assert operator link types are preserved during deserialization + self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) + + ti = TaskInstance( + task=op, + execution_date=DEFAULT_DATE + ) + + # Assert operator link is empty when no XCom push occured + self.assertEqual( + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + '' + ) + + # Assert operator link is empty for deserialized task when no XCom push occured + self.assertEqual( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + '' + ) + + ti.xcom_push(key='cluster_conf', value={ + 'cluster_name': CLUSTER_NAME, + 'region': GCP_LOCATION, + 'project_id': GCP_PROJECT + }) + + # Assert operator links are preserved in deserialized tasks after execution + self.assertEqual( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + DATAPROC_CLUSTER_LINK_EXPECTED + ) + + # Assert operator links after execution + self.assertEqual( + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + DATAPROC_CLUSTER_LINK_EXPECTED + ) + + # Check negative case + self.assertEqual( + op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), + '' + ) + class TestDataprocClusterScaleOperator(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.dagbag = DagBag( + dag_folder='/dev/null', include_examples=False + ) + cls.dag = DAG(TEST_DAG_ID, default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + }) + cls.mock_context = MagicMock() + + @classmethod + def tearDownClass(cls): + clear_db_runs() + clear_db_xcom() + def test_deprecation_warning(self): with pytest.warns(DeprecationWarning) as warnings: DataprocScaleClusterOperator(task_id=TASK_ID, cluster_name=CLUSTER_NAME, project_id=GCP_PROJECT) @@ -519,7 +626,7 @@ def test_execute(self, mock_hook): gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) + op.execute(context=self.mock_context) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.update_cluster.assert_called_once_with( @@ -531,6 +638,73 @@ def test_execute(self, mock_hook): update_mask=UPDATE_MASK, ) + def test_operator_extra_links(self): + op = DataprocScaleClusterOperator( + task_id=TASK_ID, + cluster_name=CLUSTER_NAME, + project_id=GCP_PROJECT, + region=GCP_LOCATION, + num_workers=3, + num_preemptible_workers=2, + graceful_decommission_timeout="2m", + gcp_conn_id=GCP_CONN_ID, + dag=self.dag + ) + + serialized_dag = SerializedDAG.to_dict(self.dag) + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] + + # Assert operator links for serialized DAG + self.assertEqual( + serialized_dag['dag']['tasks'][0]['_operator_extra_links'], + [{'airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink': {}}] + ) + + # Assert operator link types are preserved during deserialization + self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) + + ti = TaskInstance( + task=op, + execution_date=DEFAULT_DATE + ) + + # Assert operator link is empty when no XCom push occured + self.assertEqual( + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + '' + ) + + # Assert operator link is empty for deserialized task when no XCom push occured + self.assertEqual( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + '' + ) + + ti.xcom_push(key='cluster_conf', value={ + 'cluster_name': CLUSTER_NAME, + 'region': GCP_LOCATION, + 'project_id': GCP_PROJECT + }) + + # Assert operator links are preserved in deserialized tasks after execution + self.assertEqual( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + DATAPROC_CLUSTER_LINK_EXPECTED + ) + + # Assert operator links after execution + self.assertEqual( + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + DATAPROC_CLUSTER_LINK_EXPECTED + ) + + # Check negative case + self.assertEqual( + op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), + '' + ) + class TestDataprocClusterDeleteOperator(unittest.TestCase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -745,6 +919,23 @@ def test_operator_extra_links(self, mock_hook): class TestDataprocUpdateClusterOperator(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.dagbag = DagBag( + dag_folder='/dev/null', include_examples=False + ) + cls.dag = DAG(TEST_DAG_ID, default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + }) + cls.mock_context = MagicMock() + + @classmethod + def tearDownClass(cls): + clear_db_runs() + clear_db_xcom() + @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): op = DataprocUpdateClusterOperator( @@ -762,8 +953,10 @@ def test_execute(self, mock_hook): metadata=METADATA, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + op.execute(context=self.mock_context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + ) mock_hook.return_value.update_cluster.assert_called_once_with( location=GCP_LOCATION, project_id=GCP_PROJECT, @@ -777,6 +970,69 @@ def test_execute(self, mock_hook): metadata=METADATA, ) + def test_operator_extra_links(self): + op = DataprocUpdateClusterOperator( + task_id=TASK_ID, + location=GCP_LOCATION, + cluster_name=CLUSTER_NAME, + cluster=CLUSTER, + update_mask=UPDATE_MASK, + graceful_decommission_timeout={"graceful_decommission_timeout": "600s"}, + project_id=GCP_PROJECT, + gcp_conn_id=GCP_CONN_ID, + dag=self.dag + ) + + serialized_dag = SerializedDAG.to_dict(self.dag) + deserialized_dag = SerializedDAG.from_dict(serialized_dag) + deserialized_task = deserialized_dag.task_dict[TASK_ID] + + # Assert operator links for serialized_dag + self.assertEqual( + serialized_dag['dag']['tasks'][0]['_operator_extra_links'], + [{'airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink': {}}] + ) + + # Assert operator link types are preserved during deserialization + self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) + + ti = TaskInstance( + task=op, + execution_date=DEFAULT_DATE + ) + + # Assert operator link is empty when no XCom push occured + self.assertEqual( + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), '' + ) + + # Assert operator link is empty for deserialized task when no XCom push occured + self.assertEqual( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), '' + ) + + ti.xcom_push(key='cluster_conf', value={ + 'cluster_name': CLUSTER_NAME, + 'region': GCP_LOCATION, + 'project_id': GCP_PROJECT + }) + + # Assert operator links are preserved in deserialized tasks + self.assertEqual( + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + DATAPROC_CLUSTER_LINK_EXPECTED + ) + # Assert operator links after execution + self.assertEqual( + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + DATAPROC_CLUSTER_LINK_EXPECTED + ) + # Check for negative case + self.assertEqual( + op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), + '' + ) + class TestDataprocWorkflowTemplateInstantiateOperator(unittest.TestCase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) From 69543e50d95cdf0b018bff4c91cb72de23c1b863 Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar Date: Tue, 18 Aug 2020 19:36:45 +0530 Subject: [PATCH 07/17] apache#9941 Add Baseclass for dataproc testing --- .../google/cloud/operators/test_dataproc.py | 188 +++++++----------- 1 file changed, 72 insertions(+), 116 deletions(-) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index e39f0e4f1d08f..cc6069ba02a14 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -187,12 +187,43 @@ DATAPROC_CLUSTER_LINK_EXPECTED = \ f'https://console.cloud.google.com/dataproc/clusters/{CLUSTER_NAME}/monitoring?' \ f'region={GCP_LOCATION}&project={GCP_PROJECT}' +DATAPROC_JOB_CONF_EXPECTED = { + 'job_id': TEST_JOB_ID, + 'region': GCP_LOCATION, + 'project_id': GCP_PROJECT +} +DATAPROC_CLUSTER_CONF_EXPECTED = { + 'cluster_name': CLUSTER_NAME, + 'region': GCP_LOCATION, + 'project_id': GCP_PROJECT +} def assert_warning(msg: str, warnings): assert any(msg in str(w) for w in warnings) +class DataprocTestBase(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dagbag = DagBag( + dag_folder='/dev/null', include_examples=False + ) + cls.dag = DAG(TEST_DAG_ID, default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + }) + cls.mock_ti = MagicMock() + cls.mock_context = { + 'ti': cls.mock_ti + } + + @classmethod + def tearDownClass(cls): + clear_db_runs() + clear_db_xcom() + + class TestsClusterGenerator(unittest.TestCase): def test_image_version(self): with pytest.raises(ValueError) as ctx: @@ -306,24 +337,7 @@ def test_build_with_custom_image_family(self): assert CONFIG_WITH_CUSTOM_IMAGE_FAMILY == cluster -class TestDataprocClusterCreateOperator(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.dagbag = DagBag( - dag_folder='/dev/null', include_examples=False - ) - cls.dag = DAG(TEST_DAG_ID, default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - }) - cls.mock_context = MagicMock() - - @classmethod - def tearDownClass(cls): - clear_db_runs() - clear_db_xcom() - +class TestDataprocClusterCreateOperator(DataprocTestBase): def test_deprecation_warning(self): with pytest.warns(DeprecationWarning) as warnings: op = DataprocCreateClusterOperator( @@ -384,6 +398,11 @@ def test_execute(self, mock_hook, to_dict_mock): metadata=METADATA, ) to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result()) + self.mock_ti.xcom_push.assert_called_once_with( + key='cluster_conf', + value=DATAPROC_CLUSTER_CONF_EXPECTED, + execution_date=None + ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -561,11 +580,7 @@ def test_operator_extra_links(self): '' ) - ti.xcom_push(key='cluster_conf', value={ - 'cluster_name': CLUSTER_NAME, - 'region': GCP_LOCATION, - 'project_id': GCP_PROJECT - }) + ti.xcom_push(key='cluster_conf', value=DATAPROC_CLUSTER_CONF_EXPECTED) # Assert operator links are preserved in deserialized tasks after execution self.assertEqual( @@ -586,24 +601,7 @@ def test_operator_extra_links(self): ) -class TestDataprocClusterScaleOperator(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.dagbag = DagBag( - dag_folder='/dev/null', include_examples=False - ) - cls.dag = DAG(TEST_DAG_ID, default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - }) - cls.mock_context = MagicMock() - - @classmethod - def tearDownClass(cls): - clear_db_runs() - clear_db_xcom() - +class TestDataprocClusterScaleOperator(DataprocTestBase): def test_deprecation_warning(self): with pytest.warns(DeprecationWarning) as warnings: DataprocScaleClusterOperator(task_id=TASK_ID, cluster_name=CLUSTER_NAME, project_id=GCP_PROJECT) @@ -637,6 +635,11 @@ def test_execute(self, mock_hook): graceful_decommission_timeout={"seconds": 600}, update_mask=UPDATE_MASK, ) + self.mock_ti.xcom_push.assert_called_once_with( + key='cluster_conf', + value=DATAPROC_CLUSTER_CONF_EXPECTED, + execution_date=None + ) def test_operator_extra_links(self): op = DataprocScaleClusterOperator( @@ -681,11 +684,7 @@ def test_operator_extra_links(self): '' ) - ti.xcom_push(key='cluster_conf', value={ - 'cluster_name': CLUSTER_NAME, - 'region': GCP_LOCATION, - 'project_id': GCP_PROJECT - }) + ti.xcom_push(key='cluster_conf', value=DATAPROC_CLUSTER_CONF_EXPECTED) # Assert operator links are preserved in deserialized tasks after execution self.assertEqual( @@ -735,24 +734,7 @@ def test_execute(self, mock_hook): ) -class TestDataprocSubmitJobOperator(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.dagbag = DagBag( - dag_folder='/dev/null', include_examples=False - ) - cls.dag = DAG(TEST_DAG_ID, default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - }) - cls.mock_context = MagicMock() - - @classmethod - def tearDownClass(cls): - clear_db_runs() - clear_db_xcom() - +class TestDataprocSubmitJobOperator(DataprocTestBase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): job = {} @@ -787,6 +769,12 @@ def test_execute(self, mock_hook): job_id=job_id, project_id=GCP_PROJECT, location=GCP_LOCATION, timeout=None ) + self.mock_ti.xcom_push.assert_called_once_with( + key='job_conf', + value=DATAPROC_JOB_CONF_EXPECTED, + execution_date=None + ) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_async(self, mock_hook): job = {} @@ -807,7 +795,7 @@ def test_execute_async(self, mock_hook): request_id=REQUEST_ID, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) + op.execute(context=self.mock_context) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, @@ -844,7 +832,7 @@ def test_on_kill(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, cancel_on_kill=False, ) - op.execute(context={}) + op.execute(context=self.mock_context) op.on_kill() mock_hook.return_value.cancel_job.assert_not_called() @@ -895,11 +883,7 @@ def test_operator_extra_links(self, mock_hook): deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' ) - ti.xcom_push(key='job_conf', value={ - 'job_id': TEST_JOB_ID, - 'region': GCP_LOCATION, - 'project_id': GCP_PROJECT - }) + ti.xcom_push(key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED) # Assert operator links are preserved in deserialized tasks self.assertEqual( @@ -918,24 +902,7 @@ def test_operator_extra_links(self, mock_hook): ) -class TestDataprocUpdateClusterOperator(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.dagbag = DagBag( - dag_folder='/dev/null', include_examples=False - ) - cls.dag = DAG(TEST_DAG_ID, default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - }) - cls.mock_context = MagicMock() - - @classmethod - def tearDownClass(cls): - clear_db_runs() - clear_db_xcom() - +class TestDataprocUpdateClusterOperator(DataprocTestBase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): op = DataprocUpdateClusterOperator( @@ -969,6 +936,11 @@ def test_execute(self, mock_hook): timeout=TIMEOUT, metadata=METADATA, ) + self.mock_ti.xcom_push.assert_called_once_with( + key='cluster_conf', + value=DATAPROC_CLUSTER_CONF_EXPECTED, + execution_date=None + ) def test_operator_extra_links(self): op = DataprocUpdateClusterOperator( @@ -1011,11 +983,7 @@ def test_operator_extra_links(self): deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), '' ) - ti.xcom_push(key='cluster_conf', value={ - 'cluster_name': CLUSTER_NAME, - 'region': GCP_LOCATION, - 'project_id': GCP_PROJECT - }) + ti.xcom_push(key='cluster_conf', value=DATAPROC_CLUSTER_CONF_EXPECTED) # Assert operator links are preserved in deserialized tasks self.assertEqual( @@ -1315,7 +1283,7 @@ def test_builder(self, mock_hook, mock_uuid): assert self.job == job -class TestDataProcSparkOperator(unittest.TestCase): +class TestDataProcSparkOperator(DataprocTestBase): main_class = "org.apache.spark.examples.SparkPi" jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"] job = { @@ -1328,22 +1296,6 @@ class TestDataProcSparkOperator(unittest.TestCase): "spark_job": {"jar_file_uris": jars, "main_class": main_class}, } - @classmethod - def setUpClass(cls): - cls.dagbag = DagBag( - dag_folder='/dev/null', include_examples=False - ) - cls.dag = DAG(TEST_DAG_ID, default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - }) - cls.mock_context = MagicMock() - - @classmethod - def tearDown(cls): - clear_db_runs() - clear_db_xcom() - @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_deprecation_warning(self, mock_hook): with pytest.warns(DeprecationWarning) as warnings: @@ -1358,6 +1310,7 @@ def test_execute(self, mock_hook, mock_uuid): mock_uuid.return_value = TEST_JOB_ID mock_hook.return_value.project_id = GCP_PROJECT mock_uuid.return_value = TEST_JOB_ID + mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID op = DataprocSubmitSparkJobOperator( task_id=TASK_ID, @@ -1369,6 +1322,13 @@ def test_execute(self, mock_hook, mock_uuid): job = op.generate_job() assert self.job == job + op.execute(context=self.mock_context) + self.mock_ti.xcom_push.assert_called_once_with( + key='job_conf', + value=DATAPROC_JOB_CONF_EXPECTED, + execution_date=None + ) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_operator_extra_links(self, mock_hook): mock_hook.return_value.project_id = GCP_PROJECT @@ -1410,11 +1370,7 @@ def test_operator_extra_links(self, mock_hook): deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' ) - ti.xcom_push(key='job_conf', value={ - 'job_id': TEST_JOB_ID, - 'region': GCP_LOCATION, - 'project_id': GCP_PROJECT - }) + ti.xcom_push(key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED) # Assert operator links after task execution self.assertEqual( From b22bb82336a146e871bb6e46b116617c304c26e9 Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar Date: Thu, 27 Aug 2020 20:19:13 +0530 Subject: [PATCH 08/17] apache#9941 Apply black to dataproc operators --- .../google/cloud/operators/dataproc.py | 279 +++++++++--------- .../google/cloud/operators/test_dataproc.py | 223 +++++--------- 2 files changed, 220 insertions(+), 282 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 18dffcae79cc3..3328ff7667cf6 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -37,49 +37,59 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.taskinstance import TaskInstance -from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder +from airflow.providers.google.cloud.hooks.dataproc import ( + DataprocHook, + DataProcJobBuilder, +) from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.utils import timezone from airflow.utils.decorators import apply_defaults DATAPROC_BASE_LINK = "https://console.cloud.google.com/dataproc" -DATAPROC_JOB_LOG_LINK = \ - DATAPROC_BASE_LINK + "/jobs/{job_id}?region={region}&project={project_id}" -DATAPROC_CLUSTER_LINK = \ - DATAPROC_BASE_LINK + "/clusters/{cluster_name}/monitoring?" \ - "region={region}&project={project_id}" +DATAPROC_JOB_LOG_LINK = DATAPROC_BASE_LINK + "/jobs/{job_id}?region={region}&project={project_id}" +DATAPROC_CLUSTER_LINK = ( + DATAPROC_BASE_LINK + "/clusters/{cluster_name}/monitoring?" "region={region}&project={project_id}" +) class DataprocJobLink(BaseOperatorLink): """ Helper class for constructing Dataproc Job link """ + name = "Dataproc Job" def get_link(self, operator, dttm): ti = TaskInstance(task=operator, execution_date=dttm) - job_conf = ti.xcom_pull(task_ids=operator.task_id, key='job_conf') - return DATAPROC_JOB_LOG_LINK.format( - job_id=job_conf['job_id'], - region=job_conf['region'], - project_id=job_conf['project_id'] - ) if job_conf else '' + job_conf = ti.xcom_pull(task_ids=operator.task_id, key="job_conf") + return ( + DATAPROC_JOB_LOG_LINK.format( + job_id=job_conf["job_id"], region=job_conf["region"], project_id=job_conf["project_id"], + ) + if job_conf + else "" + ) class DataprocClusterLink(BaseOperatorLink): """ Helper class for constructing Dataproc Cluster link """ + name = "Dataproc Cluster" def get_link(self, operator, dttm): ti = TaskInstance(task=operator, execution_date=dttm) - cluster_conf = ti.xcom_pull(task_ids=operator.task_id, key='cluster_conf') - return DATAPROC_CLUSTER_LINK.format( - cluster_name=cluster_conf['cluster_name'], - region=cluster_conf['region'], - project_id=cluster_conf['project_id'] - ) if cluster_conf else '' + cluster_conf = ti.xcom_pull(task_ids=operator.task_id, key="cluster_conf") + return ( + DATAPROC_CLUSTER_LINK.format( + cluster_name=cluster_conf["cluster_name"], + region=cluster_conf["region"], + project_id=cluster_conf["project_id"], + ) + if cluster_conf + else "" + ) # pylint: disable=too-many-instance-attributes @@ -213,11 +223,11 @@ def __init__( properties: Optional[Dict] = None, optional_components: Optional[List[str]] = None, num_masters: int = 1, - master_machine_type: str = 'n1-standard-4', - master_disk_type: str = 'pd-standard', + master_machine_type: str = "n1-standard-4", + master_disk_type: str = "pd-standard", master_disk_size: int = 1024, - worker_machine_type: str = 'n1-standard-4', - worker_disk_type: str = 'pd-standard', + worker_machine_type: str = "n1-standard-4", + worker_disk_type: str = "pd-standard", worker_disk_size: int = 1024, num_preemptible_workers: int = 0, service_account: Optional[str] = None, @@ -291,7 +301,7 @@ def _get_init_action_timeout(self) -> dict: def _build_gce_cluster_config(self, cluster_data): if self.zone: - zone_uri = 'https://www.googleapis.com/compute/v1/projects/{}/zones/{}'.format( + zone_uri = "https://www.googleapis.com/compute/v1/projects/{}/zones/{}".format( self.project_id, self.zone ) cluster_data['gce_cluster_config']['zone_uri'] = zone_uri @@ -379,7 +389,7 @@ def _build_cluster_data(self): 'boot_disk_type': self.worker_disk_type, 'boot_disk_size_gb': self.worker_disk_size, }, - 'is_preemptible': True, + "is_preemptible": True, } if self.storage_bucket: @@ -391,8 +401,8 @@ def _build_cluster_data(self): elif self.custom_image: project_id = self.custom_image_project_id or self.project_id custom_image_url = ( - 'https://www.googleapis.com/compute/beta/projects/' - '{}/global/images/{}'.format(project_id, self.custom_image) + "https://www.googleapis.com/compute/beta/projects/" + "{}/global/images/{}".format(project_id, self.custom_image) ) cluster_data['master_config']['image_uri'] = custom_image_url if not self.single_node: @@ -423,7 +433,7 @@ def _build_cluster_data(self): if self.init_actions_uris: init_actions_dict = [ - {'executable_file': uri, 'execution_timeout': self._get_init_action_timeout()} + {"executable_file": uri, "execution_timeout": self._get_init_action_timeout(),} for uri in self.init_actions_uris ] cluster_data['initialization_actions'] = init_actions_dict @@ -518,9 +528,7 @@ class DataprocCreateClusterOperator(BaseOperator): ) template_fields_renderers = {'cluster_config': 'json'} - operator_extra_links = ( - DataprocClusterLink(), - ) + operator_extra_links = (DataprocClusterLink(),) @apply_defaults def __init__( # pylint: disable=too-many-arguments @@ -560,8 +568,8 @@ def __init__( # pylint: disable=too-many-arguments stacklevel=1, ) # Remove result of apply defaults - if 'params' in kwargs: - del kwargs['params'] + if "params" in kwargs: + del kwargs["params"] # Create cluster object from kwargs if project_id is None: @@ -630,7 +638,9 @@ def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None: gcs_uri = hook.diagnose_cluster( region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) - self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri) + self.log.info( + "Diagnostic information for cluster %s available at: %s", self.cluster_name, gcs_uri, + ) if self.delete_on_error: self._delete_cluster(hook) raise AirflowException("Cluster was created but was in ERROR state.") @@ -743,9 +753,7 @@ class DataprocScaleClusterOperator(BaseOperator): template_fields = ['cluster_name', 'project_id', 'region', 'impersonation_chain'] - operator_extra_links = ( - DataprocClusterLink(), - ) + operator_extra_links = (DataprocClusterLink(),) @apply_defaults def __init__( @@ -753,7 +761,7 @@ def __init__( *, cluster_name: str, project_id: Optional[str] = None, - region: str = 'global', + region: str = "global", num_workers: int = 2, num_preemptible_workers: int = 0, graceful_decommission_timeout: Optional[str] = None, @@ -782,9 +790,9 @@ def __init__( def _build_scale_cluster_data(self) -> dict: scale_data = { - 'config': { - 'worker_config': {'num_instances': self.num_workers}, - 'secondary_worker_config': {'num_instances': self.num_preemptible_workers}, + "config": { + "worker_config": {"num_instances": self.num_workers}, + "secondary_worker_config": {"num_instances": self.num_preemptible_workers}, } } return scale_data @@ -816,14 +824,17 @@ def _graceful_decommission_timeout_object(self) -> Optional[Dict[str, int]]: " i.e. 1d, 4h, 10m, 30s" ) - return {'seconds': timeout} + return {"seconds": timeout} def execute(self, context) -> None: """Scale, up or down, a cluster on Google Cloud Dataproc.""" self.log.info("Scaling cluster: %s", self.cluster_name) scaling_cluster_data = self._build_scale_cluster_data() - update_mask = ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances"] + update_mask = [ + "config.worker_config.num_instances", + "config.secondary_worker_config.num_instances", + ] hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) operation = hook.update_cluster( @@ -832,15 +843,15 @@ def execute(self, context) -> None: cluster_name=self.cluster_name, cluster=scaling_cluster_data, graceful_decommission_timeout=self._graceful_decommission_timeout_object, - update_mask={'paths': update_mask}, + update_mask={"paths": update_mask}, ) operation.result() self.log.info("Cluster scaling finished") - self.xcom_push(context, key='cluster_conf', value={ - 'cluster_name': self.cluster_name, - 'region': self.region, - 'project_id': self.project_id - }) + self.xcom_push( + context, + key="cluster_conf", + value={"cluster_name": self.cluster_name, "region": self.region, "project_id": self.project_id,}, + ) class DataprocDeleteClusterOperator(BaseOperator): @@ -989,20 +1000,18 @@ class DataprocJobBaseOperator(BaseOperator): job_type = "" - operator_extra_links = ( - DataprocJobLink(), - ) + operator_extra_links = (DataprocJobLink(),) @apply_defaults def __init__( self, *, - job_name: str = '{{task.task_id}}_{{ds_nodash}}', + job_name: str = "{{task.task_id}}_{{ds_nodash}}", cluster_name: str = "cluster-1", project_id: Optional[str] = None, dataproc_properties: Optional[Dict] = None, dataproc_jars: Optional[List[str]] = None, - gcp_conn_id: str = 'google_cloud_default', + gcp_conn_id: str = "google_cloud_default", delegate_to: Optional[str] = None, labels: Optional[Dict] = None, region: Optional[str] = None, @@ -1054,14 +1063,14 @@ def create_job_template(self): def _generate_job_template(self) -> str: if self.job_template: job = self.job_template.build() - return job['job'] + return job["job"] raise Exception("Create a job template before") def execute(self, context): if self.job_template: self.job = self.job_template.build() self.dataproc_job_id = self.job["job"]["reference"]["job_id"] - self.log.info('Submitting %s job %s', self.job_type, self.dataproc_job_id) + self.log.info("Submitting %s job %s", self.job_type, self.dataproc_job_id) job_object = self.hook.submit_job( project_id=self.project_id, job=self.job["job"], location=self.region ) @@ -1089,7 +1098,7 @@ def on_kill(self): """ if self.dataproc_job_id: self.hook.cancel_job( - project_id=self.project_id, job_id=self.dataproc_job_id, location=self.region + project_id=self.project_id, job_id=self.dataproc_job_id, location=self.region, ) @@ -1137,22 +1146,20 @@ class DataprocSubmitPigJobOperator(DataprocJobBaseOperator): """ template_fields = [ - 'query', - 'variables', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "query", + "variables", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ] template_ext = ('.pg', '.pig') ui_color = '#0273d4' job_type = 'pig_job' - operator_extra_links = ( - DataprocJobLink(), - ) + operator_extra_links = (DataprocJobLink(),) @apply_defaults def __init__( @@ -1217,14 +1224,14 @@ class DataprocSubmitHiveJobOperator(DataprocJobBaseOperator): """ template_fields = [ - 'query', - 'variables', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "query", + "variables", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ] template_ext = ('.q', '.hql') ui_color = '#0273d4' @@ -1253,7 +1260,7 @@ def __init__( self.query_uri = query_uri self.variables = variables if self.query is not None and self.query_uri is not None: - raise AirflowException('Only one of `query` and `query_uri` can be passed.') + raise AirflowException("Only one of `query` and `query_uri` can be passed.") def generate_job(self): """ @@ -1292,18 +1299,18 @@ class DataprocSubmitSparkSqlJobOperator(DataprocJobBaseOperator): """ template_fields = [ - 'query', - 'variables', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "query", + "variables", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ] - template_ext = ('.q',) - ui_color = '#0273d4' - job_type = 'spark_sql_job' + template_ext = (".q",) + ui_color = "#0273d4" + job_type = "spark_sql_job" @apply_defaults def __init__( @@ -1328,7 +1335,7 @@ def __init__( self.query_uri = query_uri self.variables = variables if self.query is not None and self.query_uri is not None: - raise AirflowException('Only one of `query` and `query_uri` can be passed.') + raise AirflowException("Only one of `query` and `query_uri` can be passed.") def generate_job(self): """ @@ -1374,16 +1381,16 @@ class DataprocSubmitSparkJobOperator(DataprocJobBaseOperator): """ template_fields = [ - 'arguments', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "arguments", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ] - ui_color = '#0273d4' - job_type = 'spark_job' + ui_color = "#0273d4" + job_type = "spark_job" @apply_defaults def __init__( @@ -1454,16 +1461,16 @@ class DataprocSubmitHadoopJobOperator(DataprocJobBaseOperator): """ template_fields = [ - 'arguments', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "arguments", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ] - ui_color = '#0273d4' - job_type = 'hadoop_job' + ui_color = "#0273d4" + job_type = "hadoop_job" @apply_defaults def __init__( @@ -1534,17 +1541,17 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator): """ template_fields = [ - 'main', - 'arguments', - 'job_name', - 'cluster_name', - 'region', - 'dataproc_jars', - 'dataproc_properties', - 'impersonation_chain', + "main", + "arguments", + "job_name", + "cluster_name", + "region", + "dataproc_jars", + "dataproc_properties", + "impersonation_chain", ] - ui_color = '#0273d4' - job_type = 'pyspark_job' + ui_color = "#0273d4" + job_type = "pyspark_job" @staticmethod def _generate_temp_filename(filename): @@ -1565,7 +1572,7 @@ def _upload_file_temp(self, bucket, local_file): GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain).upload( bucket_name=bucket, object_name=temp_filename, - mime_type='application/x-python', + mime_type="application/x-python", filename=local_file, ) return f"gs://{bucket}/{temp_filename}" @@ -1606,7 +1613,7 @@ def generate_job(self): # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( - project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name + project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name, ) bucket = cluster_info['config']['config_bucket'] self.main = f"gs://{bucket}/{self.main}" @@ -1623,9 +1630,9 @@ def execute(self, context): # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( - project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name + project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name, ) - bucket = cluster_info['config']['config_bucket'] + bucket = cluster_info["config"]["config_bucket"] self.main = self._upload_file_temp(bucket, self.main) self.job_template.set_python_main(self.main) @@ -1798,7 +1805,7 @@ def execute(self, context): metadata=self.metadata, ) operation.result() - self.log.info('Template instantiated.') + self.log.info("Template instantiated.") class DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator): @@ -1892,7 +1899,7 @@ def execute(self, context): metadata=self.metadata, ) operation.result() - self.log.info('Template instantiated.') + self.log.info("Template instantiated.") class DataprocSubmitJobOperator(BaseOperator): @@ -1944,9 +1951,7 @@ class DataprocSubmitJobOperator(BaseOperator): template_fields = ('project_id', 'location', 'job', 'impersonation_chain', 'request_id') template_fields_renderers = {"job": "json"} - operator_extra_links = ( - DataprocJobLink(), - ) + operator_extra_links = (DataprocJobLink(),) @apply_defaults def __init__( @@ -2011,11 +2016,11 @@ def on_kill(self): if self.job_id and self.cancel_on_kill: self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id, location=self.location) - self.xcom_push(context, key='job_conf', value={ - 'job_id': job_id, - 'region': self.location, - 'project_id': self.project_id - }) + self.xcom_push( + context, + key="job_conf", + value={"job_id": job_id, "region": self.location, "project_id": self.project_id,}, + ) class DataprocUpdateClusterOperator(BaseOperator): @@ -2124,8 +2129,12 @@ def execute(self, context: Dict): ) operation.result() self.log.info("Updated %s cluster.", self.cluster_name) - self.xcom_push(context, key='cluster_conf', value={ - 'cluster_name': self.cluster_name, - 'region': self.location, - 'project_id': self.project_id - }) + self.xcom_push( + context, + key="cluster_conf", + value={ + "cluster_name": self.cluster_name, + "region": self.location, + "project_id": self.project_id, + }, + ) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index cc6069ba02a14..93d2557707d0f 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -44,7 +44,7 @@ DataprocSubmitPySparkJobOperator, DataprocSubmitSparkJobOperator, DataprocSubmitSparkSqlJobOperator, - DataprocUpdateClusterOperator + DataprocUpdateClusterOperator, ) from airflow.serialization.serialized_objects import SerializedDAG from airflow.version import version as airflow_version @@ -179,23 +179,25 @@ } TEST_DAG_ID = 'test-dataproc-operators' DEFAULT_DATE = datetime(2020, 1, 1) -TEST_JOB_ID = 'test-job' - -DATAPROC_JOB_LINK_EXPECTED = \ - f'https://console.cloud.google.com/dataproc/jobs/{TEST_JOB_ID}?' \ - f'region={GCP_LOCATION}&project={GCP_PROJECT}' -DATAPROC_CLUSTER_LINK_EXPECTED = \ - f'https://console.cloud.google.com/dataproc/clusters/{CLUSTER_NAME}/monitoring?' \ - f'region={GCP_LOCATION}&project={GCP_PROJECT}' +TEST_JOB_ID = "test-job" + +DATAPROC_JOB_LINK_EXPECTED = ( + f"https://console.cloud.google.com/dataproc/jobs/{TEST_JOB_ID}?" + f"region={GCP_LOCATION}&project={GCP_PROJECT}" +) +DATAPROC_CLUSTER_LINK_EXPECTED = ( + f"https://console.cloud.google.com/dataproc/clusters/{CLUSTER_NAME}/monitoring?" + f"region={GCP_LOCATION}&project={GCP_PROJECT}" +) DATAPROC_JOB_CONF_EXPECTED = { - 'job_id': TEST_JOB_ID, - 'region': GCP_LOCATION, - 'project_id': GCP_PROJECT + "job_id": TEST_JOB_ID, + "region": GCP_LOCATION, + "project_id": GCP_PROJECT, } DATAPROC_CLUSTER_CONF_EXPECTED = { - 'cluster_name': CLUSTER_NAME, - 'region': GCP_LOCATION, - 'project_id': GCP_PROJECT + "cluster_name": CLUSTER_NAME, + "region": GCP_LOCATION, + "project_id": GCP_PROJECT, } @@ -206,17 +208,10 @@ def assert_warning(msg: str, warnings): class DataprocTestBase(unittest.TestCase): @classmethod def setUpClass(cls): - cls.dagbag = DagBag( - dag_folder='/dev/null', include_examples=False - ) - cls.dag = DAG(TEST_DAG_ID, default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE - }) + cls.dagbag = DagBag(dag_folder="/dev/null", include_examples=False) + cls.dag = DAG(TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}) cls.mock_ti = MagicMock() - cls.mock_context = { - 'ti': cls.mock_ti - } + cls.mock_context = {"ti": cls.mock_ti} @classmethod def tearDownClass(cls): @@ -399,9 +394,7 @@ def test_execute(self, mock_hook, to_dict_mock): ) to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result()) self.mock_ti.xcom_push.assert_called_once_with( - key='cluster_conf', - value=DATAPROC_CLUSTER_CONF_EXPECTED, - execution_date=None + key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None, ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @@ -547,7 +540,7 @@ def test_operator_extra_links(self): cluster=CLUSTER, delete_on_error=True, gcp_conn_id=GCP_CONN_ID, - dag=self.dag + dag=self.dag, ) serialized_dag = SerializedDAG.to_dict(self.dag) @@ -556,49 +549,38 @@ def test_operator_extra_links(self): # Assert operator links for serialized DAG self.assertEqual( - serialized_dag['dag']['tasks'][0]['_operator_extra_links'], - [{'airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink': {}}] + serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], + [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}], ) # Assert operator link types are preserved during deserialization self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) - ti = TaskInstance( - task=op, - execution_date=DEFAULT_DATE - ) + ti = TaskInstance(task=op, execution_date=DEFAULT_DATE) # Assert operator link is empty when no XCom push occured - self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - '' - ) + self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "") # Assert operator link is empty for deserialized task when no XCom push occured self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - '' + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "", ) - ti.xcom_push(key='cluster_conf', value=DATAPROC_CLUSTER_CONF_EXPECTED) + ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) # Assert operator links are preserved in deserialized tasks after execution self.assertEqual( deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED + DATAPROC_CLUSTER_LINK_EXPECTED, ) # Assert operator links after execution self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), DATAPROC_CLUSTER_LINK_EXPECTED, ) # Check negative case - self.assertEqual( - op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), - '' - ) + self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "") class TestDataprocClusterScaleOperator(DataprocTestBase): @@ -636,9 +618,7 @@ def test_execute(self, mock_hook): update_mask=UPDATE_MASK, ) self.mock_ti.xcom_push.assert_called_once_with( - key='cluster_conf', - value=DATAPROC_CLUSTER_CONF_EXPECTED, - execution_date=None + key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None, ) def test_operator_extra_links(self): @@ -651,7 +631,7 @@ def test_operator_extra_links(self): num_preemptible_workers=2, graceful_decommission_timeout="2m", gcp_conn_id=GCP_CONN_ID, - dag=self.dag + dag=self.dag, ) serialized_dag = SerializedDAG.to_dict(self.dag) @@ -660,49 +640,38 @@ def test_operator_extra_links(self): # Assert operator links for serialized DAG self.assertEqual( - serialized_dag['dag']['tasks'][0]['_operator_extra_links'], - [{'airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink': {}}] + serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], + [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}], ) # Assert operator link types are preserved during deserialization self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) - ti = TaskInstance( - task=op, - execution_date=DEFAULT_DATE - ) + ti = TaskInstance(task=op, execution_date=DEFAULT_DATE) # Assert operator link is empty when no XCom push occured - self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - '' - ) + self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "") # Assert operator link is empty for deserialized task when no XCom push occured self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - '' + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "", ) - ti.xcom_push(key='cluster_conf', value=DATAPROC_CLUSTER_CONF_EXPECTED) + ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) # Assert operator links are preserved in deserialized tasks after execution self.assertEqual( deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED + DATAPROC_CLUSTER_LINK_EXPECTED, ) # Assert operator links after execution self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), DATAPROC_CLUSTER_LINK_EXPECTED, ) # Check negative case - self.assertEqual( - op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), - '' - ) + self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "") class TestDataprocClusterDeleteOperator(unittest.TestCase): @@ -770,9 +739,7 @@ def test_execute(self, mock_hook): ) self.mock_ti.xcom_push.assert_called_once_with( - key='job_conf', - value=DATAPROC_JOB_CONF_EXPECTED, - execution_date=None + key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -852,7 +819,7 @@ def test_operator_extra_links(self, mock_hook): project_id=GCP_PROJECT, job={}, gcp_conn_id=GCP_CONN_ID, - dag=self.dag + dag=self.dag, ) serialized_dag = SerializedDAG.to_dict(self.dag) @@ -861,45 +828,33 @@ def test_operator_extra_links(self, mock_hook): # Assert operator links for serialized_dag self.assertEqual( - serialized_dag['dag']['tasks'][0]['_operator_extra_links'], - [{'airflow.providers.google.cloud.operators.dataproc.DataprocJobLink': {}}] + serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], + [{"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}}], ) # Assert operator link types are preserved during deserialization self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocJobLink) - ti = TaskInstance( - task=op, - execution_date=DEFAULT_DATE - ) + ti = TaskInstance(task=op, execution_date=DEFAULT_DATE) # Assert operator link is empty when no XCom push occured - self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' - ) + self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "") # Assert operator link is empty for deserialized task when no XCom push occured - self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' - ) + self.assertEqual(deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "") - ti.xcom_push(key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED) + ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED) # Assert operator links are preserved in deserialized tasks self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), - DATAPROC_JOB_LINK_EXPECTED + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), DATAPROC_JOB_LINK_EXPECTED, ) # Assert operator links after execution self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), - DATAPROC_JOB_LINK_EXPECTED + op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), DATAPROC_JOB_LINK_EXPECTED, ) # Check for negative case - self.assertEqual( - op.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), - '' - ) + self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), "") class TestDataprocUpdateClusterOperator(DataprocTestBase): @@ -937,9 +892,7 @@ def test_execute(self, mock_hook): metadata=METADATA, ) self.mock_ti.xcom_push.assert_called_once_with( - key='cluster_conf', - value=DATAPROC_CLUSTER_CONF_EXPECTED, - execution_date=None + key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None, ) def test_operator_extra_links(self): @@ -952,7 +905,7 @@ def test_operator_extra_links(self): graceful_decommission_timeout={"graceful_decommission_timeout": "600s"}, project_id=GCP_PROJECT, gcp_conn_id=GCP_CONN_ID, - dag=self.dag + dag=self.dag, ) serialized_dag = SerializedDAG.to_dict(self.dag) @@ -961,45 +914,36 @@ def test_operator_extra_links(self): # Assert operator links for serialized_dag self.assertEqual( - serialized_dag['dag']['tasks'][0]['_operator_extra_links'], - [{'airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink': {}}] + serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], + [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}], ) # Assert operator link types are preserved during deserialization self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink) - ti = TaskInstance( - task=op, - execution_date=DEFAULT_DATE - ) + ti = TaskInstance(task=op, execution_date=DEFAULT_DATE) # Assert operator link is empty when no XCom push occured - self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), '' - ) + self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "") # Assert operator link is empty for deserialized task when no XCom push occured self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), '' + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "", ) - ti.xcom_push(key='cluster_conf', value=DATAPROC_CLUSTER_CONF_EXPECTED) + ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) # Assert operator links are preserved in deserialized tasks self.assertEqual( deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED + DATAPROC_CLUSTER_LINK_EXPECTED, ) # Assert operator links after execution self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), - DATAPROC_CLUSTER_LINK_EXPECTED + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), DATAPROC_CLUSTER_LINK_EXPECTED, ) # Check for negative case - self.assertEqual( - op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), - '' - ) + self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "") class TestDataprocWorkflowTemplateInstantiateOperator(unittest.TestCase): @@ -1287,10 +1231,7 @@ class TestDataProcSparkOperator(DataprocTestBase): main_class = "org.apache.spark.examples.SparkPi" jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"] job = { - "reference": { - "project_id": GCP_PROJECT, - "job_id": "{{task.task_id}}_{{ds_nodash}}_" + TEST_JOB_ID, - }, + "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + TEST_JOB_ID,}, "placement": {"cluster_name": "cluster-1"}, "labels": {"airflow-version": AIRFLOW_VERSION}, "spark_job": {"jar_file_uris": jars, "main_class": main_class}, @@ -1324,9 +1265,7 @@ def test_execute(self, mock_hook, mock_uuid): op.execute(context=self.mock_context) self.mock_ti.xcom_push.assert_called_once_with( - key='job_conf', - value=DATAPROC_JOB_CONF_EXPECTED, - execution_date=None + key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None ) @mock.patch(DATAPROC_PATH.format("DataprocHook")) @@ -1339,7 +1278,7 @@ def test_operator_extra_links(self, mock_hook): gcp_conn_id=GCP_CONN_ID, main_class=self.main_class, dataproc_jars=self.jars, - dag=self.dag + dag=self.dag, ) serialized_dag = SerializedDAG.to_dict(self.dag) @@ -1348,46 +1287,36 @@ def test_operator_extra_links(self, mock_hook): # Assert operator links for serialized DAG self.assertEqual( - serialized_dag['dag']['tasks'][0]['_operator_extra_links'], - [{'airflow.providers.google.cloud.operators.dataproc.DataprocJobLink': {}}] + serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], + [{"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}}], ) # Assert operator link types are preserved during deserialization self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocJobLink) - ti = TaskInstance( - task=op, - execution_date=DEFAULT_DATE - ) + ti = TaskInstance(task=op, execution_date=DEFAULT_DATE) # Assert operator link is empty when no XCom push occured - self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' - ) + self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "") # Assert operator link is empty for deserialized task when no XCom push occured - self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), '' - ) + self.assertEqual(deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "") - ti.xcom_push(key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED) + ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED) # Assert operator links after task execution self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), - DATAPROC_JOB_LINK_EXPECTED + op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), DATAPROC_JOB_LINK_EXPECTED, ) # Assert operator links are preserved in deserialized tasks self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), - DATAPROC_JOB_LINK_EXPECTED + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), DATAPROC_JOB_LINK_EXPECTED, ) # Assert for negative case self.assertEqual( - deserialized_task.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), - '' + deserialized_task.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), "", ) From 6e1fb7a7598a0e8b7f3e52dae4fe59a7b813d076 Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar Date: Sat, 29 Aug 2020 03:37:23 +0530 Subject: [PATCH 09/17] apache#9941 Move xcom_push before polling for job --- .../google/cloud/operators/dataproc.py | 14 ++++++------- .../google/cloud/operators/test_dataproc.py | 20 +++++++++++++++++-- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 3328ff7667cf6..4a410254ceaa3 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -1076,6 +1076,7 @@ def execute(self, context): ) job_id = job_object.reference.job_id self.log.info('Job %s submitted successfully.', job_id) + # XCom push is referenced by extra links and has to be pushed before polling for job completion self.xcom_push(context, key='job_conf', value={ 'job_id': job_id, 'region': self.region, @@ -1087,7 +1088,6 @@ def execute(self, context): self.hook.wait_for_job(job_id=job_id, location=self.region, project_id=self.project_id) self.log.info('Job %s completed successfully.', job_id) return job_id - else: raise AirflowException("Create a job template before") @@ -2001,6 +2001,12 @@ def execute(self, context: Dict): ) job_id = job_object.reference.job_id self.log.info('Job %s submitted successfully.', job_id) + # XCom job_conf is referenced by extra links and has be pushed before we poll for job completion + self.xcom_push( + context, + key="job_conf", + value={"job_id": job_id, "region": self.location, "project_id": self.project_id,}, + ) if not self.asynchronous: self.log.info('Waiting for job %s to complete', job_id) @@ -2016,12 +2022,6 @@ def on_kill(self): if self.job_id and self.cancel_on_kill: self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id, location=self.location) - self.xcom_push( - context, - key="job_conf", - value={"job_id": job_id, "region": self.location, "project_id": self.project_id,}, - ) - class DataprocUpdateClusterOperator(BaseOperator): """ diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 93d2557707d0f..e6599784de157 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -19,7 +19,7 @@ import unittest from datetime import datetime from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import call, MagicMock, Mock import pytest from google.api_core.exceptions import AlreadyExists, NotFound @@ -212,6 +212,12 @@ def setUpClass(cls): cls.dag = DAG(TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}) cls.mock_ti = MagicMock() cls.mock_context = {"ti": cls.mock_ti} + cls.extra_links_expected_calls = [ + call.ti.xcom_push(execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED), + call.hook().wait_for_job(job_id=TEST_JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT), + ] + cls.extra_links_manager_mock = Mock() + cls.extra_links_manager_mock.attach_mock(cls.mock_ti, 'ti') @classmethod def tearDownClass(cls): @@ -709,6 +715,7 @@ def test_execute(self, mock_hook): job = {} mock_hook.return_value.wait_for_job.return_value = None mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID + self.extra_links_manager_mock.attach_mock(mock_hook, 'hook') op = DataprocSubmitJobOperator( task_id=TASK_ID, @@ -724,7 +731,12 @@ def test_execute(self, mock_hook): ) op.execute(context=self.mock_context) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + # Test whether xcom push occurs before polling for job + self.extra_links_manager_mock.assert_has_calls(self.extra_links_expected_calls, any_order=False) + + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, @@ -1252,6 +1264,7 @@ def test_execute(self, mock_hook, mock_uuid): mock_hook.return_value.project_id = GCP_PROJECT mock_uuid.return_value = TEST_JOB_ID mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID + self.extra_links_manager_mock.attach_mock(mock_hook, 'hook') op = DataprocSubmitSparkJobOperator( task_id=TASK_ID, @@ -1268,6 +1281,9 @@ def test_execute(self, mock_hook, mock_uuid): key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None ) + # Test whether xcom push occurs before polling for job + self.extra_links_manager_mock.assert_has_calls(self.extra_links_expected_calls, any_order=False) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_operator_extra_links(self, mock_hook): mock_hook.return_value.project_id = GCP_PROJECT From f32b10abe9f90395356f46c440c8b79302522259 Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar Date: Mon, 31 Aug 2020 02:40:29 +0530 Subject: [PATCH 10/17] apache#9941 Save links data after invoking hook --- .../google/cloud/operators/dataproc.py | 46 +++--- .../google/cloud/operators/test_dataproc.py | 131 ++++++++++++------ 2 files changed, 117 insertions(+), 60 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 4a410254ceaa3..a505bfa417846 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -672,8 +672,14 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster: return cluster def execute(self, context) -> dict: - self.log.info('Creating cluster: %s', self.cluster_name) - hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Creating cluster: %s", self.cluster_name) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + # Save data required to display extra link no matter what the cluster status will be + self.xcom_push( + context, + key="cluster_conf", + value={"cluster_name": self.cluster_name, "region": self.region, "project_id": self.project_id,}, + ) try: # First try to create a new cluster cluster = self._create_cluster(hook) @@ -836,7 +842,13 @@ def execute(self, context) -> None: "config.secondary_worker_config.num_instances", ] - hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + # Save data required to display extra link no matter what the cluster status will be + self.xcom_push( + context, + key="cluster_conf", + value={"cluster_name": self.cluster_name, "region": self.region, "project_id": self.project_id,}, + ) operation = hook.update_cluster( project_id=self.project_id, location=self.region, @@ -847,11 +859,6 @@ def execute(self, context) -> None: ) operation.result() self.log.info("Cluster scaling finished") - self.xcom_push( - context, - key="cluster_conf", - value={"cluster_name": self.cluster_name, "region": self.region, "project_id": self.project_id,}, - ) class DataprocDeleteClusterOperator(BaseOperator): @@ -1076,7 +1083,7 @@ def execute(self, context): ) job_id = job_object.reference.job_id self.log.info('Job %s submitted successfully.', job_id) - # XCom push is referenced by extra links and has to be pushed before polling for job completion + # Save data required for extra links no matter what the job status will be self.xcom_push(context, key='job_conf', value={ 'job_id': job_id, 'region': self.region, @@ -2001,7 +2008,7 @@ def execute(self, context: Dict): ) job_id = job_object.reference.job_id self.log.info('Job %s submitted successfully.', job_id) - # XCom job_conf is referenced by extra links and has be pushed before we poll for job completion + # Save data required by extra links no matter what the job status will be self.xcom_push( context, key="job_conf", @@ -2114,6 +2121,16 @@ def __init__( # pylint: disable=too-many-arguments def execute(self, context: Dict): hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + # Save data required by extra links no matter what the cluster status will be + self.xcom_push( + context, + key="cluster_conf", + value={ + "cluster_name": self.cluster_name, + "region": self.location, + "project_id": self.project_id, + }, + ) self.log.info("Updating %s cluster.", self.cluster_name) operation = hook.update_cluster( project_id=self.project_id, @@ -2129,12 +2146,3 @@ def execute(self, context: Dict): ) operation.result() self.log.info("Updated %s cluster.", self.cluster_name) - self.xcom_push( - context, - key="cluster_conf", - value={ - "cluster_name": self.cluster_name, - "region": self.location, - "project_id": self.project_id, - }, - ) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index e6599784de157..5fbd3eeaaf823 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -212,10 +212,6 @@ def setUpClass(cls): cls.dag = DAG(TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}) cls.mock_ti = MagicMock() cls.mock_context = {"ti": cls.mock_ti} - cls.extra_links_expected_calls = [ - call.ti.xcom_push(execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED), - call.hook().wait_for_job(job_id=TEST_JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT), - ] cls.extra_links_manager_mock = Mock() cls.extra_links_manager_mock.attach_mock(cls.mock_ti, 'ti') @@ -225,6 +221,25 @@ def tearDownClass(cls): clear_db_xcom() +class DataprocJobTestBase(DataprocTestBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.extra_links_expected_calls = [ + call.ti.xcom_push(execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED), + call.hook().wait_for_job(job_id=TEST_JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT), + ] + + +class DataprocClusterTestBase(DataprocTestBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.extra_links_expected_calls_base = [ + call.ti.xcom_push(execution_date=None, key='cluster_conf', value=DATAPROC_CLUSTER_CONF_EXPECTED) + ] + + class TestsClusterGenerator(unittest.TestCase): def test_image_version(self): with pytest.raises(ValueError) as ctx: @@ -338,7 +353,7 @@ def test_build_with_custom_image_family(self): assert CONFIG_WITH_CUSTOM_IMAGE_FAMILY == cluster -class TestDataprocClusterCreateOperator(DataprocTestBase): +class TestDataprocClusterCreateOperator(DataprocClusterTestBase): def test_deprecation_warning(self): with pytest.warns(DeprecationWarning) as warnings: op = DataprocCreateClusterOperator( @@ -369,6 +384,23 @@ def test_deprecation_warning(self): @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook, to_dict_mock): + self.extra_links_manager_mock.attach_mock(mock_hook, 'hook') + mock_hook.return_value.create_cluster.result.return_value = None + create_cluster_args = { + 'region': GCP_LOCATION, + 'project_id': GCP_PROJECT, + 'cluster_name': CLUSTER_NAME, + 'request_id': REQUEST_ID, + 'retry': RETRY, + 'timeout': TIMEOUT, + 'metadata': METADATA, + 'cluster_config': CONFIG, + 'labels': LABELS + } + expected_calls = self.extra_links_expected_calls_base + [ + call.hook().create_cluster(**create_cluster_args), + ] + op = DataprocCreateClusterOperator( task_id=TASK_ID, region=GCP_LOCATION, @@ -384,21 +416,15 @@ def test_execute(self, mock_hook, to_dict_mock): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=self.mock_context) + + # Test whether xcom push occurs before create cluster is called + self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) + mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) - mock_hook.return_value.create_cluster.assert_called_once_with( - region=GCP_LOCATION, - project_id=GCP_PROJECT, - cluster_config=CONFIG, - labels=LABELS, - cluster_name=CLUSTER_NAME, - request_id=REQUEST_ID, - retry=RETRY, - timeout=TIMEOUT, - metadata=METADATA, - ) to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result()) + mock_hook.return_value.create_cluster.assert_called_once_with(**create_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None, ) @@ -589,7 +615,7 @@ def test_operator_extra_links(self): self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "") -class TestDataprocClusterScaleOperator(DataprocTestBase): +class TestDataprocClusterScaleOperator(DataprocClusterTestBase): def test_deprecation_warning(self): with pytest.warns(DeprecationWarning) as warnings: DataprocScaleClusterOperator(task_id=TASK_ID, cluster_name=CLUSTER_NAME, project_id=GCP_PROJECT) @@ -597,9 +623,22 @@ def test_deprecation_warning(self): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): + self.extra_links_manager_mock.attach_mock(mock_hook, 'hook') + mock_hook.return_value.update_cluster.result.return_value = None cluster_update = { "config": {"worker_config": {"num_instances": 3}, "secondary_worker_config": {"num_instances": 4}} } + update_cluster_args = { + 'project_id': GCP_PROJECT, + 'location': GCP_LOCATION, + 'cluster_name': CLUSTER_NAME, + 'cluster': cluster_update, + 'graceful_decommission_timeout': {"seconds": 600}, + 'update_mask': UPDATE_MASK, + } + expected_calls = self.extra_links_expected_calls_base + [ + call.hook().update_cluster(**update_cluster_args) + ] op = DataprocScaleClusterOperator( task_id=TASK_ID, @@ -614,15 +653,13 @@ def test_execute(self, mock_hook): ) op.execute(context=self.mock_context) - mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) - mock_hook.return_value.update_cluster.assert_called_once_with( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - cluster_name=CLUSTER_NAME, - cluster=cluster_update, - graceful_decommission_timeout={"seconds": 600}, - update_mask=UPDATE_MASK, + # Test whether xcom push occurs before cluster is updated + self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) + + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) + mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None, ) @@ -709,7 +746,7 @@ def test_execute(self, mock_hook): ) -class TestDataprocSubmitJobOperator(DataprocTestBase): +class TestDataprocSubmitJobOperator(DataprocJobTestBase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): job = {} @@ -869,9 +906,28 @@ def test_operator_extra_links(self, mock_hook): self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), "") -class TestDataprocUpdateClusterOperator(DataprocTestBase): +class TestDataprocUpdateClusterOperator(DataprocClusterTestBase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): + self.extra_links_manager_mock.attach_mock(mock_hook, 'hook') + mock_hook.return_value.update_cluster.result.return_value = None + cluster_decommission_timeout = {"graceful_decommission_timeout": "600s"} + update_cluster_args = { + 'location': GCP_LOCATION, + 'project_id': GCP_PROJECT, + 'cluster_name': CLUSTER_NAME, + 'cluster': CLUSTER, + 'update_mask': UPDATE_MASK, + 'graceful_decommission_timeout': cluster_decommission_timeout, + 'request_id': REQUEST_ID, + 'retry': RETRY, + 'timeout': TIMEOUT, + 'metadata': METADATA, + } + expected_calls = self.extra_links_expected_calls_base + [ + call.hook().update_cluster(**update_cluster_args) + ] + op = DataprocUpdateClusterOperator( task_id=TASK_ID, location=GCP_LOCATION, @@ -879,7 +935,7 @@ def test_execute(self, mock_hook): cluster=CLUSTER, update_mask=UPDATE_MASK, request_id=REQUEST_ID, - graceful_decommission_timeout={"graceful_decommission_timeout": "600s"}, + graceful_decommission_timeout=cluster_decommission_timeout, project_id=GCP_PROJECT, gcp_conn_id=GCP_CONN_ID, retry=RETRY, @@ -888,21 +944,14 @@ def test_execute(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=self.mock_context) + + # Test whether the xcom push happens before updating the cluster + self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) + mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) - mock_hook.return_value.update_cluster.assert_called_once_with( - location=GCP_LOCATION, - project_id=GCP_PROJECT, - cluster_name=CLUSTER_NAME, - cluster=CLUSTER, - update_mask=UPDATE_MASK, - graceful_decommission_timeout={"graceful_decommission_timeout": "600s"}, - request_id=REQUEST_ID, - retry=RETRY, - timeout=TIMEOUT, - metadata=METADATA, - ) + mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None, ) @@ -1239,7 +1288,7 @@ def test_builder(self, mock_hook, mock_uuid): assert self.job == job -class TestDataProcSparkOperator(DataprocTestBase): +class TestDataProcSparkOperator(DataprocJobTestBase): main_class = "org.apache.spark.examples.SparkPi" jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"] job = { From e4878fc44d412faeaa597e2393101cc4d2b92c6b Mon Sep 17 00:00:00 2001 From: Tomek Urbaszek Date: Mon, 31 Aug 2020 09:21:40 +0200 Subject: [PATCH 11/17] Update airflow/providers/google/cloud/operators/dataproc.py --- airflow/providers/google/cloud/operators/dataproc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index a505bfa417846..ae0bf239293a1 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -48,7 +48,7 @@ DATAPROC_BASE_LINK = "https://console.cloud.google.com/dataproc" DATAPROC_JOB_LOG_LINK = DATAPROC_BASE_LINK + "/jobs/{job_id}?region={region}&project={project_id}" DATAPROC_CLUSTER_LINK = ( - DATAPROC_BASE_LINK + "/clusters/{cluster_name}/monitoring?" "region={region}&project={project_id}" + DATAPROC_BASE_LINK + "/clusters/{cluster_name}/monitoring?region={region}&project={project_id}" ) From d960acbc208a5ab7129975650a0df1f4b7a14897 Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar <8852302+yesemsanthoshkumar@users.noreply.github.com> Date: Sun, 13 Dec 2020 19:20:22 +0530 Subject: [PATCH 12/17] apache#9941 Add tests for async execution --- airflow/providers/google/provider.yaml | 2 + .../google/cloud/operators/test_dataproc.py | 118 +++++++++++++----- 2 files changed, 86 insertions(+), 34 deletions(-) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 644f093be2297..87ba4ae1a8608 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -743,6 +743,8 @@ extra-links: - airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink - airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink - airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink + - airflow.providers.google.cloud.operators.dataproc.DataprocJobLink + - airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink additional-extras: apache.beam: apache-beam[gcp] diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 5fbd3eeaaf823..4733f10a64f20 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -19,7 +19,7 @@ import unittest from datetime import datetime from unittest import mock -from unittest.mock import call, MagicMock, Mock +from unittest.mock import MagicMock, Mock, call import pytest from google.api_core.exceptions import AlreadyExists, NotFound @@ -210,10 +210,18 @@ class DataprocTestBase(unittest.TestCase): def setUpClass(cls): cls.dagbag = DagBag(dag_folder="/dev/null", include_examples=False) cls.dag = DAG(TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}) - cls.mock_ti = MagicMock() - cls.mock_context = {"ti": cls.mock_ti} - cls.extra_links_manager_mock = Mock() - cls.extra_links_manager_mock.attach_mock(cls.mock_ti, 'ti') + + def setUp(self): + self.mock_ti = MagicMock() + self.mock_context = {"ti": self.mock_ti} + self.extra_links_manager_mock = Mock() + self.extra_links_manager_mock.attach_mock(self.mock_ti, 'ti') + + def tearDown(self): + self.mock_ti = MagicMock() + self.mock_context = {"ti": self.mock_ti} + self.extra_links_manager_mock = Mock() + self.extra_links_manager_mock.attach_mock(self.mock_ti, 'ti') @classmethod def tearDownClass(cls): @@ -395,7 +403,7 @@ def test_execute(self, mock_hook, to_dict_mock): 'timeout': TIMEOUT, 'metadata': METADATA, 'cluster_config': CONFIG, - 'labels': LABELS + 'labels': LABELS, } expected_calls = self.extra_links_expected_calls_base + [ call.hook().create_cluster(**create_cluster_args), @@ -421,12 +429,15 @@ def test_execute(self, mock_hook, to_dict_mock): self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, ) to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result()) mock_hook.return_value.create_cluster.assert_called_once_with(**create_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( - key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None, + key="cluster_conf", + value=DATAPROC_CLUSTER_CONF_EXPECTED, + execution_date=None, ) @mock.patch(DATAPROC_PATH.format("Cluster.to_dict")) @@ -450,7 +461,8 @@ def test_execute_if_cluster_exists(self, mock_hook, to_dict_mock): ) op.execute(context=self.mock_context) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.create_cluster.assert_called_once_with( region=GCP_LOCATION, @@ -569,7 +581,7 @@ def test_operator_extra_links(self): task_id=TASK_ID, region=GCP_LOCATION, project_id=GCP_PROJECT, - cluster=CLUSTER, + cluster_name=CLUSTER_NAME, delete_on_error=True, gcp_conn_id=GCP_CONN_ID, dag=self.dag, @@ -595,7 +607,8 @@ def test_operator_extra_links(self): # Assert operator link is empty for deserialized task when no XCom push occured self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "", + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + "", ) ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) @@ -608,7 +621,8 @@ def test_operator_extra_links(self): # Assert operator links after execution self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), DATAPROC_CLUSTER_LINK_EXPECTED, + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + DATAPROC_CLUSTER_LINK_EXPECTED, ) # Check negative case @@ -657,11 +671,14 @@ def test_execute(self, mock_hook): self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( - key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None, + key="cluster_conf", + value=DATAPROC_CLUSTER_CONF_EXPECTED, + execution_date=None, ) def test_operator_extra_links(self): @@ -697,7 +714,8 @@ def test_operator_extra_links(self): # Assert operator link is empty for deserialized task when no XCom push occured self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "", + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + "", ) ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) @@ -710,7 +728,8 @@ def test_operator_extra_links(self): # Assert operator links after execution self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), DATAPROC_CLUSTER_LINK_EXPECTED, + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + DATAPROC_CLUSTER_LINK_EXPECTED, ) # Check negative case @@ -749,6 +768,13 @@ def test_execute(self, mock_hook): class TestDataprocSubmitJobOperator(DataprocJobTestBase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute(self, mock_hook): + xcom_push_call = call.ti.xcom_push( + execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED + ) + wait_for_job_call = call.hook().wait_for_job( + job_id=TEST_JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, timeout=None + ) + job = {} mock_hook.return_value.wait_for_job.return_value = None mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID @@ -769,10 +795,15 @@ def test_execute(self, mock_hook): op.execute(context=self.mock_context) # Test whether xcom push occurs before polling for job - self.extra_links_manager_mock.assert_has_calls(self.extra_links_expected_calls, any_order=False) + self.assertLess( + self.extra_links_manager_mock.mock_calls.index(xcom_push_call), + self.extra_links_manager_mock.mock_calls.index(wait_for_job_call), + msg='Xcom push for Job Link has to be done before polling for job status', + ) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, @@ -784,7 +815,7 @@ def test_execute(self, mock_hook): metadata=METADATA, ) mock_hook.return_value.wait_for_job.assert_called_once_with( - job_id=job_id, project_id=GCP_PROJECT, location=GCP_LOCATION, timeout=None + job_id=TEST_JOB_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, timeout=None ) self.mock_ti.xcom_push.assert_called_once_with( @@ -794,9 +825,8 @@ def test_execute(self, mock_hook): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_async(self, mock_hook): job = {} - job_id = "job_id" mock_hook.return_value.wait_for_job.return_value = None - mock_hook.return_value.submit_job.return_value.reference.job_id = job_id + mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID op = DataprocSubmitJobOperator( task_id=TASK_ID, @@ -828,6 +858,10 @@ def test_execute_async(self, mock_hook): ) mock_hook.return_value.wait_for_job.assert_not_called() + self.mock_ti.xcom_push.assert_called_once_with( + key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None + ) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_on_kill(self, mock_hook): job = {} @@ -896,11 +930,13 @@ def test_operator_extra_links(self, mock_hook): # Assert operator links are preserved in deserialized tasks self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), DATAPROC_JOB_LINK_EXPECTED, + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), + DATAPROC_JOB_LINK_EXPECTED, ) # Assert operator links after execution self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), DATAPROC_JOB_LINK_EXPECTED, + op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), + DATAPROC_JOB_LINK_EXPECTED, ) # Check for negative case self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), "") @@ -949,11 +985,14 @@ def test_execute(self, mock_hook): self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( - key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None, + key="cluster_conf", + value=DATAPROC_CLUSTER_CONF_EXPECTED, + execution_date=None, ) def test_operator_extra_links(self): @@ -989,7 +1028,8 @@ def test_operator_extra_links(self): # Assert operator link is empty for deserialized task when no XCom push occured self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "", + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + "", ) ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED) @@ -1001,7 +1041,8 @@ def test_operator_extra_links(self): ) # Assert operator links after execution self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), DATAPROC_CLUSTER_LINK_EXPECTED, + op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), + DATAPROC_CLUSTER_LINK_EXPECTED, ) # Check for negative case self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "") @@ -1108,7 +1149,8 @@ def test_execute(self, mock_hook, mock_uuid): ) op.execute(context=MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION @@ -1169,7 +1211,8 @@ def test_execute(self, mock_hook, mock_uuid): ) op.execute(context=MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION @@ -1236,7 +1279,8 @@ def test_execute(self, mock_hook, mock_uuid): ) op.execute(context=MagicMock()) mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION @@ -1292,7 +1336,10 @@ class TestDataProcSparkOperator(DataprocJobTestBase): main_class = "org.apache.spark.examples.SparkPi" jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"] job = { - "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + TEST_JOB_ID,}, + "reference": { + "project_id": GCP_PROJECT, + "job_id": "{{task.task_id}}_{{ds_nodash}}_" + TEST_JOB_ID, + }, "placement": {"cluster_name": "cluster-1"}, "labels": {"airflow-version": AIRFLOW_VERSION}, "spark_job": {"jar_file_uris": jars, "main_class": main_class}, @@ -1371,17 +1418,20 @@ def test_operator_extra_links(self, mock_hook): # Assert operator links after task execution self.assertEqual( - op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), DATAPROC_JOB_LINK_EXPECTED, + op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), + DATAPROC_JOB_LINK_EXPECTED, ) # Assert operator links are preserved in deserialized tasks self.assertEqual( - deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), DATAPROC_JOB_LINK_EXPECTED, + deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), + DATAPROC_JOB_LINK_EXPECTED, ) # Assert for negative case self.assertEqual( - deserialized_task.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), "", + deserialized_task.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), + "", ) From 35512f4b03c516d613a86283122f9e175a85fc84 Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar <8852302+yesemsanthoshkumar@users.noreply.github.com> Date: Tue, 15 Dec 2020 18:49:21 +0530 Subject: [PATCH 13/17] apache#9941 Fix tests for providers manager --- .../google/cloud/operators/dataproc.py | 76 ++++++++++++------- tests/core/test_providers_manager.py | 2 + 2 files changed, 52 insertions(+), 26 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index ae0bf239293a1..ebba93925b0cf 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -37,10 +37,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink from airflow.models.taskinstance import TaskInstance -from airflow.providers.google.cloud.hooks.dataproc import ( - DataprocHook, - DataProcJobBuilder, -) +from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.utils import timezone from airflow.utils.decorators import apply_defaults @@ -53,9 +50,7 @@ class DataprocJobLink(BaseOperatorLink): - """ - Helper class for constructing Dataproc Job link - """ + """Helper class for constructing Dataproc Job link""" name = "Dataproc Job" @@ -64,7 +59,9 @@ def get_link(self, operator, dttm): job_conf = ti.xcom_pull(task_ids=operator.task_id, key="job_conf") return ( DATAPROC_JOB_LOG_LINK.format( - job_id=job_conf["job_id"], region=job_conf["region"], project_id=job_conf["project_id"], + job_id=job_conf["job_id"], + region=job_conf["region"], + project_id=job_conf["project_id"], ) if job_conf else "" @@ -72,9 +69,7 @@ def get_link(self, operator, dttm): class DataprocClusterLink(BaseOperatorLink): - """ - Helper class for constructing Dataproc Cluster link - """ + """Helper class for constructing Dataproc Cluster link""" name = "Dataproc Cluster" @@ -433,7 +428,10 @@ def _build_cluster_data(self): if self.init_actions_uris: init_actions_dict = [ - {"executable_file": uri, "execution_timeout": self._get_init_action_timeout(),} + { + "executable_file": uri, + "execution_timeout": self._get_init_action_timeout(), + } for uri in self.init_actions_uris ] cluster_data['initialization_actions'] = init_actions_dict @@ -639,7 +637,9 @@ def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None: region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) self.log.info( - "Diagnostic information for cluster %s available at: %s", self.cluster_name, gcs_uri, + "Diagnostic information for cluster %s available at: %s", + self.cluster_name, + gcs_uri, ) if self.delete_on_error: self._delete_cluster(hook) @@ -673,12 +673,19 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster: def execute(self, context) -> dict: self.log.info("Creating cluster: %s", self.cluster_name) - hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + hook = DataprocHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) # Save data required to display extra link no matter what the cluster status will be self.xcom_push( context, key="cluster_conf", - value={"cluster_name": self.cluster_name, "region": self.region, "project_id": self.project_id,}, + value={ + "cluster_name": self.cluster_name, + "region": self.region, + "project_id": self.project_id, + }, ) try: # First try to create a new cluster @@ -842,12 +849,19 @@ def execute(self, context) -> None: "config.secondary_worker_config.num_instances", ] - hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain,) + hook = DataprocHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) # Save data required to display extra link no matter what the cluster status will be self.xcom_push( context, key="cluster_conf", - value={"cluster_name": self.cluster_name, "region": self.region, "project_id": self.project_id,}, + value={ + "cluster_name": self.cluster_name, + "region": self.region, + "project_id": self.project_id, + }, ) operation = hook.update_cluster( project_id=self.project_id, @@ -1084,11 +1098,11 @@ def execute(self, context): job_id = job_object.reference.job_id self.log.info('Job %s submitted successfully.', job_id) # Save data required for extra links no matter what the job status will be - self.xcom_push(context, key='job_conf', value={ - 'job_id': job_id, - 'region': self.region, - 'project_id': self.project_id - }) + self.xcom_push( + context, + key='job_conf', + value={'job_id': job_id, 'region': self.region, 'project_id': self.project_id}, + ) if not self.asynchronous: self.log.info('Waiting for job %s to complete', job_id) @@ -1105,7 +1119,9 @@ def on_kill(self): """ if self.dataproc_job_id: self.hook.cancel_job( - project_id=self.project_id, job_id=self.dataproc_job_id, location=self.region, + project_id=self.project_id, + job_id=self.dataproc_job_id, + location=self.region, ) @@ -1620,7 +1636,9 @@ def generate_job(self): # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( - project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name, + project_id=self.hook.project_id, + region=self.region, + cluster_name=self.cluster_name, ) bucket = cluster_info['config']['config_bucket'] self.main = f"gs://{bucket}/{self.main}" @@ -1637,7 +1655,9 @@ def execute(self, context): # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( - project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name, + project_id=self.hook.project_id, + region=self.region, + cluster_name=self.cluster_name, ) bucket = cluster_info["config"]["config_bucket"] self.main = self._upload_file_temp(bucket, self.main) @@ -2012,7 +2032,11 @@ def execute(self, context: Dict): self.xcom_push( context, key="job_conf", - value={"job_id": job_id, "region": self.location, "project_id": self.project_id,}, + value={ + "job_id": job_id, + "region": self.location, + "project_id": self.project_id, + }, ) if not self.asynchronous: diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py index 57b581b366c7c..e6f26ce5e16c8 100644 --- a/tests/core/test_providers_manager.py +++ b/tests/core/test_providers_manager.py @@ -227,6 +227,8 @@ EXTRA_LINKS = [ 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink', 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink', + 'airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink', + 'airflow.providers.google.cloud.operators.dataproc.DataprocJobLink', 'airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink', 'airflow.providers.qubole.operators.qubole.QDSLink', ] From 2fe9e58975542fa2ac1fae03f6e0fd3e1abbe1bb Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar <8852302+yesemsanthoshkumar@users.noreply.github.com> Date: Thu, 17 Dec 2020 20:54:34 +0530 Subject: [PATCH 14/17] apache#9941 Fix static checks --- airflow/serialization/serialized_objects.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index c5ed0cad83216..645844af7573b 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -78,9 +78,10 @@ "airflow.providers.google.cloud.operators.dataproc.DataprocJobLink", "airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink", "airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink", - "airflow.providers.qubole.operators.qubole.QDSLink" + "airflow.providers.qubole.operators.qubole.QDSLink", ] + @cache def get_operator_extra_links(): """ From 0ef1dad83fd9e970e158121424d0e25de417d85c Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar <8852302+yesemsanthoshkumar@users.noreply.github.com> Date: Thu, 17 Dec 2020 21:52:49 +0530 Subject: [PATCH 15/17] apache#9941 Include dataproc links in discovery --- scripts/in_container/run_install_and_test_provider_packages.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh index 6dd19084f6843..9dd85407d3fc0 100755 --- a/scripts/in_container/run_install_and_test_provider_packages.sh +++ b/scripts/in_container/run_install_and_test_provider_packages.sh @@ -137,7 +137,7 @@ function discover_all_extra_links() { group_start "Listing available extra links via 'airflow providers links'" COLUMNS=180 airflow providers links - local expected_number_of_extra_links=4 + local expected_number_of_extra_links=6 local actual_number_of_extra_links actual_number_of_extra_links=$(airflow providers links --output table | grep -c ^airflow.providers | xargs) if [[ ${actual_number_of_extra_links} != "${expected_number_of_extra_links}" ]]; then From 33820d0bcaee3f5d7facaf8f561ae6ba827e07ea Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar <8852302+yesemsanthoshkumar@users.noreply.github.com> Date: Tue, 4 May 2021 01:51:43 +0530 Subject: [PATCH 16/17] apache#9941 Remove unrelated quote changes --- .../google/cloud/operators/dataproc.py | 179 +++++++++--------- 1 file changed, 88 insertions(+), 91 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index ebba93925b0cf..d8bcebe0dbe0b 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -218,11 +218,11 @@ def __init__( properties: Optional[Dict] = None, optional_components: Optional[List[str]] = None, num_masters: int = 1, - master_machine_type: str = "n1-standard-4", - master_disk_type: str = "pd-standard", + master_machine_type: str = 'n1-standard-4', + master_disk_type: str = 'pd-standard', master_disk_size: int = 1024, - worker_machine_type: str = "n1-standard-4", - worker_disk_type: str = "pd-standard", + worker_machine_type: str = 'n1-standard-4', + worker_disk_type: str = 'pd-standard', worker_disk_size: int = 1024, num_preemptible_workers: int = 0, service_account: Optional[str] = None, @@ -296,7 +296,7 @@ def _get_init_action_timeout(self) -> dict: def _build_gce_cluster_config(self, cluster_data): if self.zone: - zone_uri = "https://www.googleapis.com/compute/v1/projects/{}/zones/{}".format( + zone_uri = 'https://www.googleapis.com/compute/v1/projects/{}/zones/{}'.format( self.project_id, self.zone ) cluster_data['gce_cluster_config']['zone_uri'] = zone_uri @@ -384,7 +384,7 @@ def _build_cluster_data(self): 'boot_disk_type': self.worker_disk_type, 'boot_disk_size_gb': self.worker_disk_size, }, - "is_preemptible": True, + 'is_preemptible': True, } if self.storage_bucket: @@ -396,8 +396,8 @@ def _build_cluster_data(self): elif self.custom_image: project_id = self.custom_image_project_id or self.project_id custom_image_url = ( - "https://www.googleapis.com/compute/beta/projects/" - "{}/global/images/{}".format(project_id, self.custom_image) + 'https://www.googleapis.com/compute/beta/projects/' + '{}/global/images/{}'.format(project_id, self.custom_image) ) cluster_data['master_config']['image_uri'] = custom_image_url if not self.single_node: @@ -566,8 +566,8 @@ def __init__( # pylint: disable=too-many-arguments stacklevel=1, ) # Remove result of apply defaults - if "params" in kwargs: - del kwargs["params"] + if 'params' in kwargs: + del kwargs['params'] # Create cluster object from kwargs if project_id is None: @@ -709,11 +709,11 @@ def execute(self, context) -> dict: cluster = self._create_cluster(hook) self._handle_error_state(hook, cluster) - self.xcom_push(context, key='cluster_conf', value={ - 'cluster_name': self.cluster_name, - 'region': self.region, - 'project_id': self.project_id - }) + self.xcom_push( + context, + key='cluster_conf', + value={'cluster_name': self.cluster_name, 'region': self.region, 'project_id': self.project_id}, + ) return Cluster.to_dict(cluster) @@ -774,7 +774,7 @@ def __init__( *, cluster_name: str, project_id: Optional[str] = None, - region: str = "global", + region: str = 'global', num_workers: int = 2, num_preemptible_workers: int = 0, graceful_decommission_timeout: Optional[str] = None, @@ -803,9 +803,9 @@ def __init__( def _build_scale_cluster_data(self) -> dict: scale_data = { - "config": { - "worker_config": {"num_instances": self.num_workers}, - "secondary_worker_config": {"num_instances": self.num_preemptible_workers}, + 'config': { + 'worker_config': {'num_instances': self.num_workers}, + 'secondary_worker_config': {'num_instances': self.num_preemptible_workers}, } } return scale_data @@ -837,7 +837,7 @@ def _graceful_decommission_timeout_object(self) -> Optional[Dict[str, int]]: " i.e. 1d, 4h, 10m, 30s" ) - return {"seconds": timeout} + return {'seconds': timeout} def execute(self, context) -> None: """Scale, up or down, a cluster on Google Cloud Dataproc.""" @@ -869,7 +869,7 @@ def execute(self, context) -> None: cluster_name=self.cluster_name, cluster=scaling_cluster_data, graceful_decommission_timeout=self._graceful_decommission_timeout_object, - update_mask={"paths": update_mask}, + update_mask={'paths': update_mask}, ) operation.result() self.log.info("Cluster scaling finished") @@ -1027,12 +1027,12 @@ class DataprocJobBaseOperator(BaseOperator): def __init__( self, *, - job_name: str = "{{task.task_id}}_{{ds_nodash}}", + job_name: str = '{{task.task_id}}_{{ds_nodash}}', cluster_name: str = "cluster-1", project_id: Optional[str] = None, dataproc_properties: Optional[Dict] = None, dataproc_jars: Optional[List[str]] = None, - gcp_conn_id: str = "google_cloud_default", + gcp_conn_id: str = 'google_cloud_default', delegate_to: Optional[str] = None, labels: Optional[Dict] = None, region: Optional[str] = None, @@ -1084,14 +1084,14 @@ def create_job_template(self): def _generate_job_template(self) -> str: if self.job_template: job = self.job_template.build() - return job["job"] + return job['job'] raise Exception("Create a job template before") def execute(self, context): if self.job_template: self.job = self.job_template.build() self.dataproc_job_id = self.job["job"]["reference"]["job_id"] - self.log.info("Submitting %s job %s", self.job_type, self.dataproc_job_id) + self.log.info('Submitting %s job %s', self.job_type, self.dataproc_job_id) job_object = self.hook.submit_job( project_id=self.project_id, job=self.job["job"], location=self.region ) @@ -1112,7 +1112,7 @@ def execute(self, context): else: raise AirflowException("Create a job template before") - def on_kill(self): + def on_kill(self) -> None: """ Callback called when the operator is killed. Cancel any running job. @@ -1169,14 +1169,14 @@ class DataprocSubmitPigJobOperator(DataprocJobBaseOperator): """ template_fields = [ - "query", - "variables", - "job_name", - "cluster_name", - "region", - "dataproc_jars", - "dataproc_properties", - "impersonation_chain", + 'query', + 'variables', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', ] template_ext = ('.pg', '.pig') ui_color = '#0273d4' @@ -1231,7 +1231,6 @@ def execute(self, context): self.job_template.add_variables(self.variables) super().execute(context) - # context['task_instance'].xcom_push(key='job_id', value=self.job['reference']['job_id']) class DataprocSubmitHiveJobOperator(DataprocJobBaseOperator): @@ -1247,14 +1246,14 @@ class DataprocSubmitHiveJobOperator(DataprocJobBaseOperator): """ template_fields = [ - "query", - "variables", - "job_name", - "cluster_name", - "region", - "dataproc_jars", - "dataproc_properties", - "impersonation_chain", + 'query', + 'variables', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', ] template_ext = ('.q', '.hql') ui_color = '#0273d4' @@ -1283,7 +1282,7 @@ def __init__( self.query_uri = query_uri self.variables = variables if self.query is not None and self.query_uri is not None: - raise AirflowException("Only one of `query` and `query_uri` can be passed.") + raise AirflowException('Only one of `query` and `query_uri` can be passed.') def generate_job(self): """ @@ -1322,18 +1321,18 @@ class DataprocSubmitSparkSqlJobOperator(DataprocJobBaseOperator): """ template_fields = [ - "query", - "variables", - "job_name", - "cluster_name", - "region", - "dataproc_jars", - "dataproc_properties", - "impersonation_chain", + 'query', + 'variables', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', ] - template_ext = (".q",) - ui_color = "#0273d4" - job_type = "spark_sql_job" + template_ext = ('.q',) + ui_color = '#0273d4' + job_type = 'spark_sql_job' @apply_defaults def __init__( @@ -1358,7 +1357,7 @@ def __init__( self.query_uri = query_uri self.variables = variables if self.query is not None and self.query_uri is not None: - raise AirflowException("Only one of `query` and `query_uri` can be passed.") + raise AirflowException('Only one of `query` and `query_uri` can be passed.') def generate_job(self): """ @@ -1404,16 +1403,16 @@ class DataprocSubmitSparkJobOperator(DataprocJobBaseOperator): """ template_fields = [ - "arguments", - "job_name", - "cluster_name", - "region", - "dataproc_jars", - "dataproc_properties", - "impersonation_chain", + 'arguments', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', ] - ui_color = "#0273d4" - job_type = "spark_job" + ui_color = '#0273d4' + job_type = 'spark_job' @apply_defaults def __init__( @@ -1484,16 +1483,16 @@ class DataprocSubmitHadoopJobOperator(DataprocJobBaseOperator): """ template_fields = [ - "arguments", - "job_name", - "cluster_name", - "region", - "dataproc_jars", - "dataproc_properties", - "impersonation_chain", + 'arguments', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', ] - ui_color = "#0273d4" - job_type = "hadoop_job" + ui_color = '#0273d4' + job_type = 'hadoop_job' @apply_defaults def __init__( @@ -1564,17 +1563,17 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator): """ template_fields = [ - "main", - "arguments", - "job_name", - "cluster_name", - "region", - "dataproc_jars", - "dataproc_properties", - "impersonation_chain", + 'main', + 'arguments', + 'job_name', + 'cluster_name', + 'region', + 'dataproc_jars', + 'dataproc_properties', + 'impersonation_chain', ] - ui_color = "#0273d4" - job_type = "pyspark_job" + ui_color = '#0273d4' + job_type = 'pyspark_job' @staticmethod def _generate_temp_filename(filename): @@ -1595,7 +1594,7 @@ def _upload_file_temp(self, bucket, local_file): GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain).upload( bucket_name=bucket, object_name=temp_filename, - mime_type="application/x-python", + mime_type='application/x-python', filename=local_file, ) return f"gs://{bucket}/{temp_filename}" @@ -1659,7 +1658,7 @@ def execute(self, context): region=self.region, cluster_name=self.cluster_name, ) - bucket = cluster_info["config"]["config_bucket"] + bucket = cluster_info['config']['config_bucket'] self.main = self._upload_file_temp(bucket, self.main) self.job_template.set_python_main(self.main) @@ -1832,7 +1831,7 @@ def execute(self, context): metadata=self.metadata, ) operation.result() - self.log.info("Template instantiated.") + self.log.info('Template instantiated.') class DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator): @@ -1926,7 +1925,7 @@ def execute(self, context): metadata=self.metadata, ) operation.result() - self.log.info("Template instantiated.") + self.log.info('Template instantiated.') class DataprocSubmitJobOperator(BaseOperator): @@ -2107,9 +2106,7 @@ class DataprocUpdateClusterOperator(BaseOperator): """ template_fields = ('impersonation_chain', 'cluster_name') - operator_extra_links = ( - DataprocClusterLink(), - ) + operator_extra_links = (DataprocClusterLink(),) @apply_defaults def __init__( # pylint: disable=too-many-arguments From 9991a5442e3c4fe09b661deb55dd11cbaa1e10fa Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar <8852302+yesemsanthoshkumar@users.noreply.github.com> Date: Thu, 6 May 2021 01:52:57 +0530 Subject: [PATCH 17/17] apache#9941 Fix tests --- .../google/cloud/operators/dataproc.py | 47 ++++-------------- .../google/cloud/operators/test_dataproc.py | 49 +++++-------------- 2 files changed, 23 insertions(+), 73 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index d8bcebe0dbe0b..d8df03a9505f0 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -428,10 +428,7 @@ def _build_cluster_data(self): if self.init_actions_uris: init_actions_dict = [ - { - "executable_file": uri, - "execution_timeout": self._get_init_action_timeout(), - } + {'executable_file': uri, 'execution_timeout': self._get_init_action_timeout()} for uri in self.init_actions_uris ] cluster_data['initialization_actions'] = init_actions_dict @@ -636,11 +633,7 @@ def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None: gcs_uri = hook.diagnose_cluster( region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) - self.log.info( - "Diagnostic information for cluster %s available at: %s", - self.cluster_name, - gcs_uri, - ) + self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri) if self.delete_on_error: self._delete_cluster(hook) raise AirflowException("Cluster was created but was in ERROR state.") @@ -672,11 +665,8 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster: return cluster def execute(self, context) -> dict: - self.log.info("Creating cluster: %s", self.cluster_name) - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + self.log.info('Creating cluster: %s', self.cluster_name) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) # Save data required to display extra link no matter what the cluster status will be self.xcom_push( context, @@ -709,11 +699,6 @@ def execute(self, context) -> dict: cluster = self._create_cluster(hook) self._handle_error_state(hook, cluster) - self.xcom_push( - context, - key='cluster_conf', - value={'cluster_name': self.cluster_name, 'region': self.region, 'project_id': self.project_id}, - ) return Cluster.to_dict(cluster) @@ -844,15 +829,9 @@ def execute(self, context) -> None: self.log.info("Scaling cluster: %s", self.cluster_name) scaling_cluster_data = self._build_scale_cluster_data() - update_mask = [ - "config.worker_config.num_instances", - "config.secondary_worker_config.num_instances", - ] - - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + update_mask = ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances"] + + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) # Save data required to display extra link no matter what the cluster status will be self.xcom_push( context, @@ -1119,9 +1098,7 @@ def on_kill(self) -> None: """ if self.dataproc_job_id: self.hook.cancel_job( - project_id=self.project_id, - job_id=self.dataproc_job_id, - location=self.region, + project_id=self.project_id, job_id=self.dataproc_job_id, location=self.region ) @@ -1635,9 +1612,7 @@ def generate_job(self): # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( - project_id=self.hook.project_id, - region=self.region, - cluster_name=self.cluster_name, + project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name ) bucket = cluster_info['config']['config_bucket'] self.main = f"gs://{bucket}/{self.main}" @@ -1654,9 +1629,7 @@ def execute(self, context): # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( - project_id=self.hook.project_id, - region=self.region, - cluster_name=self.cluster_name, + project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name ) bucket = cluster_info['config']['config_bucket'] self.main = self._upload_file_temp(bucket, self.main) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 4733f10a64f20..9a0ef21b7d327 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -424,16 +424,13 @@ def test_execute(self, mock_hook, to_dict_mock): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=self.mock_context) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.create_cluster.assert_called_once_with(**create_cluster_args) # Test whether xcom push occurs before create cluster is called self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result()) - mock_hook.return_value.create_cluster.assert_called_once_with(**create_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, @@ -460,10 +457,7 @@ def test_execute_if_cluster_exists(self, mock_hook, to_dict_mock): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=self.mock_context) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.create_cluster.assert_called_once_with( region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -666,15 +660,12 @@ def test_execute(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=self.mock_context) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) # Test whether xcom push occurs before cluster is updated self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) - mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, @@ -794,6 +785,8 @@ def test_execute(self, mock_hook): ) op.execute(context=self.mock_context) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + # Test whether xcom push occurs before polling for job self.assertLess( self.extra_links_manager_mock.mock_calls.index(xcom_push_call), @@ -801,10 +794,6 @@ def test_execute(self, mock_hook): msg='Xcom push for Job Link has to be done before polling for job status', ) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, @@ -980,15 +969,12 @@ def test_execute(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=self.mock_context) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) # Test whether the xcom push happens before updating the cluster self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) - mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, @@ -1148,10 +1134,7 @@ def test_execute(self, mock_hook, mock_uuid): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=MagicMock()) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION ) @@ -1210,10 +1193,7 @@ def test_execute(self, mock_hook, mock_uuid): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=MagicMock()) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION ) @@ -1278,10 +1258,7 @@ def test_execute(self, mock_hook, mock_uuid): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=MagicMock()) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION ) @@ -1306,7 +1283,7 @@ def test_execute_override_project_id(self, mock_hook, mock_uuid): variables=self.variables, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) + op.execute(context=MagicMock()) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.submit_job.assert_called_once_with( project_id="other-project", job=self.other_project_job, location=GCP_LOCATION