diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index c8ee5e35a69f1..d8df03a9505f0 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -35,12 +35,57 @@
 from google.protobuf.field_mask_pb2 import FieldMask
 
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
+from airflow.models import BaseOperator, BaseOperatorLink
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.utils import timezone
 from airflow.utils.decorators import apply_defaults
 
+DATAPROC_BASE_LINK = "https://console.cloud.google.com/dataproc"
+DATAPROC_JOB_LOG_LINK = DATAPROC_BASE_LINK + "/jobs/{job_id}?region={region}&project={project_id}"
+DATAPROC_CLUSTER_LINK = (
+    DATAPROC_BASE_LINK + "/clusters/{cluster_name}/monitoring?region={region}&project={project_id}"
+)
+
+
+class DataprocJobLink(BaseOperatorLink):
+    """Helper class for constructing Dataproc Job link"""
+
+    name = "Dataproc Job"
+
+    def get_link(self, operator, dttm):
+        ti = TaskInstance(task=operator, execution_date=dttm)
+        job_conf = ti.xcom_pull(task_ids=operator.task_id, key="job_conf")
+        return (
+            DATAPROC_JOB_LOG_LINK.format(
+                job_id=job_conf["job_id"],
+                region=job_conf["region"],
+                project_id=job_conf["project_id"],
+            )
+            if job_conf
+            else ""
+        )
+
+
+class DataprocClusterLink(BaseOperatorLink):
+    """Helper class for constructing Dataproc Cluster link"""
+
+    name = "Dataproc Cluster"
+
+    def get_link(self, operator, dttm):
+        ti = TaskInstance(task=operator, execution_date=dttm)
+        cluster_conf = ti.xcom_pull(task_ids=operator.task_id, key="cluster_conf")
+        return (
+            DATAPROC_CLUSTER_LINK.format(
+                cluster_name=cluster_conf["cluster_name"],
+                region=cluster_conf["region"],
+                project_id=cluster_conf["project_id"],
+            )
+            if cluster_conf
+            else ""
+        )
+
 
 # pylint: disable=too-many-instance-attributes
 class ClusterGenerator:
@@ -478,6 +523,8 @@ class DataprocCreateClusterOperator(BaseOperator):
     )
     template_fields_renderers = {'cluster_config': 'json'}
 
+    operator_extra_links = (DataprocClusterLink(),)
+
     @apply_defaults
     def __init__(  # pylint: disable=too-many-arguments
         self,
@@ -620,6 +667,16 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
     def execute(self, context) -> dict:
         self.log.info('Creating cluster: %s', self.cluster_name)
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        # Save data required to display extra link no matter what the cluster status will be
+        self.xcom_push(
+            context,
+            key="cluster_conf",
+            value={
+                "cluster_name": self.cluster_name,
+                "region": self.region,
+                "project_id": self.project_id,
+            },
+        )
         try:
             # First try to create a new cluster
             cluster = self._create_cluster(hook)
@@ -694,6 +751,8 @@ class DataprocScaleClusterOperator(BaseOperator):
 
     template_fields = ['cluster_name', 'project_id', 'region', 'impersonation_chain']
 
+    operator_extra_links = (DataprocClusterLink(),)
+
     @apply_defaults
    def __init__(
         self,
@@ -773,6 +832,16 @@ def execute(self, context) -> None:
 
         update_mask = ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances"]
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        # Save data required to display extra link no matter what the cluster status will be
+        self.xcom_push(
+            context,
+            key="cluster_conf",
+            value={
+                "cluster_name": self.cluster_name,
+                "region": self.region,
+                "project_id": self.project_id,
+            },
+        )
         operation = hook.update_cluster(
             project_id=self.project_id,
             location=self.region,
@@ -931,6 +1000,8 @@ class DataprocJobBaseOperator(BaseOperator):
 
     job_type = ""
 
+    operator_extra_links = (DataprocJobLink(),)
+
     @apply_defaults
     def __init__(
         self,
@@ -1005,6 +1076,12 @@ def execute(self, context):
             )
             job_id = job_object.reference.job_id
             self.log.info('Job %s submitted successfully.', job_id)
+            # Save data required for extra links no matter what the job status will be
+            self.xcom_push(
+                context,
+                key='job_conf',
+                value={'job_id': job_id, 'region': self.region, 'project_id': self.project_id},
+            )
 
             if not self.asynchronous:
                 self.log.info('Waiting for job %s to complete', job_id)
@@ -1082,6 +1159,8 @@ class DataprocSubmitPigJobOperator(DataprocJobBaseOperator):
     ui_color = '#0273d4'
     job_type = 'pig_job'
 
+    operator_extra_links = (DataprocJobLink(),)
+
     @apply_defaults
     def __init__(
         self,
@@ -1871,6 +1950,8 @@ class DataprocSubmitJobOperator(BaseOperator):
     template_fields = ('project_id', 'location', 'job', 'impersonation_chain', 'request_id')
     template_fields_renderers = {"job": "json"}
 
+    operator_extra_links = (DataprocJobLink(),)
+
     @apply_defaults
     def __init__(
         self,
@@ -1919,6 +2000,16 @@ def execute(self, context: Dict):
         )
         job_id = job_object.reference.job_id
         self.log.info('Job %s submitted successfully.', job_id)
+        # Save data required by extra links no matter what the job status will be
+        self.xcom_push(
+            context,
+            key="job_conf",
+            value={
+                "job_id": job_id,
+                "region": self.location,
+                "project_id": self.project_id,
+            },
+        )
 
         if not self.asynchronous:
             self.log.info('Waiting for job %s to complete', job_id)
@@ -1988,6 +2079,7 @@ class DataprocUpdateClusterOperator(BaseOperator):
     """
 
     template_fields = ('impersonation_chain', 'cluster_name')
+    operator_extra_links = (DataprocClusterLink(),)
 
     @apply_defaults
     def __init__(  # pylint: disable=too-many-arguments
@@ -2023,6 +2115,16 @@ def __init__(  # pylint: disable=too-many-arguments
 
     def execute(self, context: Dict):
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        # Save data required by extra links no matter what the cluster status will be
+        self.xcom_push(
+            context,
+            key="cluster_conf",
+            value={
+                "cluster_name": self.cluster_name,
+                "region": self.location,
+                "project_id": self.project_id,
+            },
+        )
         self.log.info("Updating %s cluster.", self.cluster_name)
         operation = hook.update_cluster(
             project_id=self.project_id,
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 644f093be2297..87ba4ae1a8608 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -743,6 +743,8 @@ extra-links:
   - airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink
   - airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink
   - airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink
+  - airflow.providers.google.cloud.operators.dataproc.DataprocJobLink
+  - airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink
 
 additional-extras:
   apache.beam: apache-beam[gcp]
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 8a6fdc89d2062..645844af7573b 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -72,6 +72,15 @@
     "airflow.sensors.external_task_sensor.ExternalTaskSensorLink",
 }
 
+BUILTIN_OPERATOR_EXTRA_LINKS: List[str] = [
+    "airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink",
+    "airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink",
+    "airflow.providers.google.cloud.operators.dataproc.DataprocJobLink",
+    "airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink",
+    "airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink",
+    "airflow.providers.qubole.operators.qubole.QDSLink",
+]
+
 
 @cache
 def get_operator_extra_links():
diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh
index 6dd19084f6843..9dd85407d3fc0 100755
--- a/scripts/in_container/run_install_and_test_provider_packages.sh
+++ b/scripts/in_container/run_install_and_test_provider_packages.sh
@@ -137,7 +137,7 @@ function discover_all_extra_links() {
     group_start "Listing available extra links via 'airflow providers links'"
     COLUMNS=180 airflow providers links
 
-    local expected_number_of_extra_links=4
+    local expected_number_of_extra_links=6
     local actual_number_of_extra_links
     actual_number_of_extra_links=$(airflow providers links --output table | grep -c ^airflow.providers | xargs)
    if [[ ${actual_number_of_extra_links} != "${expected_number_of_extra_links}" ]]; then
diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py
index 57b581b366c7c..e6f26ce5e16c8 100644
--- a/tests/core/test_providers_manager.py
+++ b/tests/core/test_providers_manager.py
@@ -227,6 +227,8 @@ EXTRA_LINKS = [
     'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
     'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink',
+    'airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink',
+    'airflow.providers.google.cloud.operators.dataproc.DataprocJobLink',
     'airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink',
     'airflow.providers.qubole.operators.qubole.QDSLink',
 ]
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index fb2ceef1f98e7..9a0ef21b7d327 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -19,19 +19,23 @@
 import unittest
 from datetime import datetime
 from unittest import mock
+from unittest.mock import MagicMock, Mock, call
 
 import pytest
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry
 
 from airflow import AirflowException
+from airflow.models import DAG, DagBag, TaskInstance
 from airflow.providers.google.cloud.operators.dataproc import (
     ClusterGenerator,
+    DataprocClusterLink,
     DataprocCreateClusterOperator,
     DataprocCreateWorkflowTemplateOperator,
     DataprocDeleteClusterOperator,
     DataprocInstantiateInlineWorkflowTemplateOperator,
     DataprocInstantiateWorkflowTemplateOperator,
+    DataprocJobLink,
     DataprocScaleClusterOperator,
     DataprocSubmitHadoopJobOperator,
     DataprocSubmitHiveJobOperator,
@@ -42,7 +46,9 @@
     DataprocSubmitSparkSqlJobOperator,
     DataprocUpdateClusterOperator,
 )
+from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.version import version as airflow_version
+from tests.test_utils.db import clear_db_runs, clear_db_xcom
 
 cluster_params = inspect.signature(ClusterGenerator.__init__).parameters
@@ -171,12 +177,77 @@
     },
     "jobs": [{"step_id": "pig_job_1", "pig_job": {}}],
 }
+TEST_DAG_ID = 'test-dataproc-operators'
+DEFAULT_DATE = datetime(2020, 1, 1)
+TEST_JOB_ID = "test-job"
+
+DATAPROC_JOB_LINK_EXPECTED = (
+    f"https://console.cloud.google.com/dataproc/jobs/{TEST_JOB_ID}?"
+    f"region={GCP_LOCATION}&project={GCP_PROJECT}"
+)
+DATAPROC_CLUSTER_LINK_EXPECTED = (
+    f"https://console.cloud.google.com/dataproc/clusters/{CLUSTER_NAME}/monitoring?"
+    f"region={GCP_LOCATION}&project={GCP_PROJECT}"
+)
+DATAPROC_JOB_CONF_EXPECTED = {
+    "job_id": TEST_JOB_ID,
+    "region": GCP_LOCATION,
+    "project_id": GCP_PROJECT,
+}
+DATAPROC_CLUSTER_CONF_EXPECTED = {
+    "cluster_name": CLUSTER_NAME,
+    "region": GCP_LOCATION,
+    "project_id": GCP_PROJECT,
+}
 
 
 def assert_warning(msg: str, warnings):
     assert any(msg in str(w) for w in warnings)
 
 
+class DataprocTestBase(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.dagbag = DagBag(dag_folder="/dev/null", include_examples=False)
+        cls.dag = DAG(TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
+
+    def setUp(self):
+        self.mock_ti = MagicMock()
+        self.mock_context = {"ti": self.mock_ti}
+        self.extra_links_manager_mock = Mock()
+        self.extra_links_manager_mock.attach_mock(self.mock_ti, 'ti')
+
+    def tearDown(self):
+        self.mock_ti = MagicMock()
+        self.mock_context = {"ti": self.mock_ti}
+        self.extra_links_manager_mock = Mock()
+        self.extra_links_manager_mock.attach_mock(self.mock_ti, 'ti')
+
+    @classmethod
+    def tearDownClass(cls):
+        clear_db_runs()
+        clear_db_xcom()
+
+
+class DataprocJobTestBase(DataprocTestBase):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.extra_links_expected_calls = [
+            call.ti.xcom_push(execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED),
+            call.hook().wait_for_job(job_id=TEST_JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT),
+        ]
+
+
+class DataprocClusterTestBase(DataprocTestBase):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.extra_links_expected_calls_base = [
+            call.ti.xcom_push(execution_date=None, key='cluster_conf', value=DATAPROC_CLUSTER_CONF_EXPECTED)
+        ]
+
+
 class TestsClusterGenerator(unittest.TestCase):
     def test_image_version(self):
         with pytest.raises(ValueError) as ctx:
@@ -290,7 +361,7 @@ def test_build_with_custom_image_family(self):
         assert CONFIG_WITH_CUSTOM_IMAGE_FAMILY == cluster
 
 
-class TestDataprocClusterCreateOperator(unittest.TestCase):
+class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
     def test_deprecation_warning(self):
         with pytest.warns(DeprecationWarning) as warnings:
             op = DataprocCreateClusterOperator(
@@ -321,6 +392,23 @@ def test_deprecation_warning(self):
     @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute(self, mock_hook, to_dict_mock):
+        self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')
+        mock_hook.return_value.create_cluster.result.return_value = None
+        create_cluster_args = {
+            'region': GCP_LOCATION,
+            'project_id': GCP_PROJECT,
+            'cluster_name': CLUSTER_NAME,
+            'request_id': REQUEST_ID,
+            'retry': RETRY,
+            'timeout': TIMEOUT,
+            'metadata': METADATA,
+            'cluster_config': CONFIG,
+            'labels': LABELS,
+        }
+        expected_calls = self.extra_links_expected_calls_base + [
+            call.hook().create_cluster(**create_cluster_args),
+        ]
+
         op = DataprocCreateClusterOperator(
             task_id=TASK_ID,
             region=GCP_LOCATION,
@@ -335,20 +423,19 @@ def test_execute(self, mock_hook, to_dict_mock):
             metadata=METADATA,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={})
+        op.execute(context=self.mock_context)
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
-        mock_hook.return_value.create_cluster.assert_called_once_with(
-            region=GCP_LOCATION,
-            project_id=GCP_PROJECT,
-            cluster_config=CONFIG,
-            labels=LABELS,
-            cluster_name=CLUSTER_NAME,
-            request_id=REQUEST_ID,
-            retry=RETRY,
-            timeout=TIMEOUT,
-            metadata=METADATA,
-        )
+        mock_hook.return_value.create_cluster.assert_called_once_with(**create_cluster_args)
+
+        # Test whether xcom push occurs before create cluster is called
+        self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
+
+        to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())
+        self.mock_ti.xcom_push.assert_called_once_with(
+            key="cluster_conf",
+            value=DATAPROC_CLUSTER_CONF_EXPECTED,
+            execution_date=None,
+        )
 
     @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -369,7 +456,7 @@ def test_execute_if_cluster_exists(self, mock_hook, to_dict_mock):
             request_id=REQUEST_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={})
+        op.execute(context=self.mock_context)
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.create_cluster.assert_called_once_with(
             region=GCP_LOCATION,
@@ -411,7 +498,7 @@ def test_execute_if_cluster_exists_do_not_use(self, mock_hook):
             use_if_exists=False,
         )
         with pytest.raises(AlreadyExists):
-            op.execute(context={})
+            op.execute(context=self.mock_context)
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_if_cluster_exists_in_error_state(self, mock_hook):
@@ -435,7 +522,7 @@ def test_execute_if_cluster_exists_in_error_state(self, mock_hook):
             request_id=REQUEST_ID,
         )
         with pytest.raises(AirflowException):
-            op.execute(context={})
+            op.execute(context=self.mock_context)
 
         mock_hook.return_value.diagnose_cluster.assert_called_once_with(
             region=GCP_LOCATION, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME
@@ -474,7 +561,7 @@ def test_execute_if_cluster_exists_in_deleting_state(
             gcp_conn_id=GCP_CONN_ID,
         )
         with pytest.raises(AirflowException):
-            op.execute(context={})
+            op.execute(context=self.mock_context)
 
         calls = [mock.call(mock_hook.return_value), mock.call(mock_hook.return_value)]
         mock_get_cluster.assert_has_calls(calls)
@@ -483,8 +570,60 @@
             region=GCP_LOCATION, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME
         )
 
+    def test_operator_extra_links(self):
+        op = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            delete_on_error=True,
+            gcp_conn_id=GCP_CONN_ID,
+            dag=self.dag,
+        )
+
+        serialized_dag = SerializedDAG.to_dict(self.dag)
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+        deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+        # Assert operator links for serialized DAG
+        self.assertEqual(
+            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+            [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}],
+        )
+
+        # Assert operator link types are preserved during deserialization
+        self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
+
+        ti = TaskInstance(task=op, execution_date=DEFAULT_DATE)
+
+        # Assert operator link is empty when no XCom push occurred
+        self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "")
+
+        # Assert operator link is empty for deserialized task when no XCom push occurred
+        self.assertEqual(
+            deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+            "",
+        )
+
+        ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+
+        # Assert operator links are preserved in deserialized tasks after execution
+        self.assertEqual(
+            deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+            DATAPROC_CLUSTER_LINK_EXPECTED,
+        )
+
+        # Assert operator links after execution
+        self.assertEqual(
+            op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+            DATAPROC_CLUSTER_LINK_EXPECTED,
+        )
+
+        # Check negative case
+        self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "")
+
 
-class TestDataprocClusterScaleOperator(unittest.TestCase):
+class TestDataprocClusterScaleOperator(DataprocClusterTestBase):
     def test_deprecation_warning(self):
         with pytest.warns(DeprecationWarning) as warnings:
             DataprocScaleClusterOperator(task_id=TASK_ID, cluster_name=CLUSTER_NAME, project_id=GCP_PROJECT)
@@ -492,9 +631,22 @@
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute(self, mock_hook):
+        self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')
+        mock_hook.return_value.update_cluster.result.return_value = None
         cluster_update = {
             "config": {"worker_config": {"num_instances": 3}, "secondary_worker_config": {"num_instances": 4}}
         }
+        update_cluster_args = {
+            'project_id': GCP_PROJECT,
+            'location': GCP_LOCATION,
+            'cluster_name': CLUSTER_NAME,
+            'cluster': cluster_update,
+            'graceful_decommission_timeout': {"seconds": 600},
+            'update_mask': UPDATE_MASK,
+        }
+        expected_calls = self.extra_links_expected_calls_base + [
+            call.hook().update_cluster(**update_cluster_args)
+        ]
 
         op = DataprocScaleClusterOperator(
             task_id=TASK_ID,
@@ -507,18 +659,73 @@ def test_execute(self, mock_hook):
             gcp_conn_id=GCP_CONN_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={})
-
+        op.execute(context=self.mock_context)
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
-        mock_hook.return_value.update_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            location=GCP_LOCATION,
+        mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args)
+
+        # Test whether xcom push occurs before cluster is updated
+        self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
+
+        self.mock_ti.xcom_push.assert_called_once_with(
+            key="cluster_conf",
+            value=DATAPROC_CLUSTER_CONF_EXPECTED,
+            execution_date=None,
+        )
+
+    def test_operator_extra_links(self):
+        op = DataprocScaleClusterOperator(
+            task_id=TASK_ID,
             cluster_name=CLUSTER_NAME,
-            cluster=cluster_update,
-            graceful_decommission_timeout={"seconds": 600},
-            update_mask=UPDATE_MASK,
+            project_id=GCP_PROJECT,
+            region=GCP_LOCATION,
+            num_workers=3,
+            num_preemptible_workers=2,
+            graceful_decommission_timeout="2m",
+            gcp_conn_id=GCP_CONN_ID,
+            dag=self.dag,
         )
 
+        serialized_dag = SerializedDAG.to_dict(self.dag)
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+        deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+        # Assert operator links for serialized DAG
+        self.assertEqual(
+            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+            [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}],
+        )
+
+        # Assert operator link types are preserved during deserialization
+        self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
+
+        ti = TaskInstance(task=op, execution_date=DEFAULT_DATE)
+
+        # Assert operator link is empty when no XCom push occurred
+        self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "")
+
+        # Assert operator link is empty for deserialized task when no XCom push occurred
+        self.assertEqual(
+            deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+            "",
+        )
+
+        ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+
+        # Assert operator links are preserved in deserialized tasks after execution
+        self.assertEqual(
+            deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+            DATAPROC_CLUSTER_LINK_EXPECTED,
+        )
+
+        # Assert operator links after execution
+        self.assertEqual(
+            op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+            DATAPROC_CLUSTER_LINK_EXPECTED,
+        )
+
+        # Check negative case
+        self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "")
+
 
 class TestDataprocClusterDeleteOperator(unittest.TestCase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -549,13 +756,20 @@ def test_execute(self, mock_hook):
         )
 
 
-class TestDataprocSubmitJobOperator(unittest.TestCase):
+class TestDataprocSubmitJobOperator(DataprocJobTestBase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute(self, mock_hook):
+        xcom_push_call = call.ti.xcom_push(
+            execution_date=None, key='job_conf', value=DATAPROC_JOB_CONF_EXPECTED
+        )
+        wait_for_job_call = call.hook().wait_for_job(
+            job_id=TEST_JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, timeout=None
+        )
+
         job = {}
-        job_id = "job_id"
         mock_hook.return_value.wait_for_job.return_value = None
-        mock_hook.return_value.submit_job.return_value.reference.job_id = job_id
+        mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID
+        self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')
 
         op = DataprocSubmitJobOperator(
             task_id=TASK_ID,
@@ -569,9 +783,17 @@ def test_execute(self, mock_hook):
             request_id=REQUEST_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={})
+        op.execute(context=self.mock_context)
 
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+
+        # Test whether xcom push occurs before polling for job
+        self.assertLess(
+            self.extra_links_manager_mock.mock_calls.index(xcom_push_call),
+            self.extra_links_manager_mock.mock_calls.index(wait_for_job_call),
+            msg='Xcom push for Job Link has to be done before polling for job status',
+        )
+
         mock_hook.return_value.submit_job.assert_called_once_with(
             project_id=GCP_PROJECT,
             location=GCP_LOCATION,
@@ -582,15 +804,18 @@ def test_execute(self, mock_hook):
             metadata=METADATA,
         )
         mock_hook.return_value.wait_for_job.assert_called_once_with(
-            job_id=job_id, project_id=GCP_PROJECT, location=GCP_LOCATION, timeout=None
+            job_id=TEST_JOB_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, timeout=None
+        )
+
+        self.mock_ti.xcom_push.assert_called_once_with(
+            key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
         )
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_async(self, mock_hook):
         job = {}
-        job_id = "job_id"
         mock_hook.return_value.wait_for_job.return_value = None
-        mock_hook.return_value.submit_job.return_value.reference.job_id = job_id
+        mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID
 
         op = DataprocSubmitJobOperator(
             task_id=TASK_ID,
@@ -605,7 +830,7 @@ def test_execute_async(self, mock_hook):
             request_id=REQUEST_ID,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={})
+        op.execute(context=self.mock_context)
 
         mock_hook.assert_called_once_with(
             gcp_conn_id=GCP_CONN_ID,
@@ -622,6 +847,10 @@ def test_execute_async(self, mock_hook):
         )
         mock_hook.return_value.wait_for_job.assert_not_called()
 
+        self.mock_ti.xcom_push.assert_called_once_with(
+            key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+        )
+
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_on_kill(self, mock_hook):
         job = {}
@@ -642,7 +871,7 @@ def test_on_kill(self, mock_hook):
             impersonation_chain=IMPERSONATION_CHAIN,
             cancel_on_kill=False,
         )
-        op.execute(context={})
+        op.execute(context=self.mock_context)
 
         op.on_kill()
         mock_hook.return_value.cancel_job.assert_not_called()
@@ -653,10 +882,77 @@ def test_on_kill(self, mock_hook):
             project_id=GCP_PROJECT, location=GCP_LOCATION, job_id=job_id
         )
 
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_operator_extra_links(self, mock_hook):
+        mock_hook.return_value.project_id = GCP_PROJECT
+        op = DataprocSubmitJobOperator(
+            task_id=TASK_ID,
+            location=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            job={},
+            gcp_conn_id=GCP_CONN_ID,
+            dag=self.dag,
+        )
+
+        serialized_dag = SerializedDAG.to_dict(self.dag)
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+        deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+        # Assert operator links for serialized_dag
+        self.assertEqual(
+            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+            [{"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}}],
+        )
+
+        # Assert operator link types are preserved during deserialization
+        self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocJobLink)
+
+        ti = TaskInstance(task=op, execution_date=DEFAULT_DATE)
+
+        # Assert operator link is empty when no XCom push occurred
+        self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "")
+
+        # Assert operator link is empty for deserialized task when no XCom push occurred
+        self.assertEqual(deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "")
+
+        ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED)
 
-class TestDataprocUpdateClusterOperator(unittest.TestCase):
+        # Assert operator links are preserved in deserialized tasks
+        self.assertEqual(
+            deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name),
+            DATAPROC_JOB_LINK_EXPECTED,
+        )
+        # Assert operator links after execution
+        self.assertEqual(
+            op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name),
+            DATAPROC_JOB_LINK_EXPECTED,
+        )
+        # Check for negative case
+        self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name), "")
+
+
+class TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute(self, mock_hook):
+        self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')
+        mock_hook.return_value.update_cluster.result.return_value = None
+        cluster_decommission_timeout = {"graceful_decommission_timeout": "600s"}
+        update_cluster_args = {
+            'location': GCP_LOCATION,
+            'project_id': GCP_PROJECT,
+            'cluster_name': CLUSTER_NAME,
+            'cluster': CLUSTER,
+            'update_mask': UPDATE_MASK,
+            'graceful_decommission_timeout': cluster_decommission_timeout,
+            'request_id': REQUEST_ID,
+            'retry': RETRY,
+            'timeout': TIMEOUT,
+            'metadata': METADATA,
+        }
+        expected_calls = self.extra_links_expected_calls_base + [
+            call.hook().update_cluster(**update_cluster_args)
+        ]
+
         op = DataprocUpdateClusterOperator(
             task_id=TASK_ID,
             location=GCP_LOCATION,
@@ -664,7 +960,7 @@ def test_execute(self, mock_hook):
             cluster=CLUSTER,
             update_mask=UPDATE_MASK,
             request_id=REQUEST_ID,
-            graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
+            graceful_decommission_timeout=cluster_decommission_timeout,
             project_id=GCP_PROJECT,
             gcp_conn_id=GCP_CONN_ID,
             retry=RETRY,
@@ -672,21 +968,71 @@ def test_execute(self, mock_hook):
             metadata=METADATA,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={})
+        op.execute(context=self.mock_context)
 
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
-        mock_hook.return_value.update_cluster.assert_called_once_with(
+        mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args)
+
+        # Test whether the xcom push happens before updating the cluster
+        self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
+
+        self.mock_ti.xcom_push.assert_called_once_with(
+            key="cluster_conf",
+            value=DATAPROC_CLUSTER_CONF_EXPECTED,
+            execution_date=None,
+        )
+
+    def test_operator_extra_links(self):
+        op = DataprocUpdateClusterOperator(
+            task_id=TASK_ID,
             location=GCP_LOCATION,
-            project_id=GCP_PROJECT,
             cluster_name=CLUSTER_NAME,
             cluster=CLUSTER,
             update_mask=UPDATE_MASK,
             graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
-            request_id=REQUEST_ID,
-            retry=RETRY,
-            timeout=TIMEOUT,
-            metadata=METADATA,
+            project_id=GCP_PROJECT,
+            gcp_conn_id=GCP_CONN_ID,
+            dag=self.dag,
         )
 
+        serialized_dag = SerializedDAG.to_dict(self.dag)
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+        deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+        # Assert operator links for serialized_dag
+        self.assertEqual(
+            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+            [{"airflow.providers.google.cloud.operators.dataproc.DataprocClusterLink": {}}],
+        )
+
+        # Assert operator link types are preserved during deserialization
+        self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocClusterLink)
+
+        ti = TaskInstance(task=op, execution_date=DEFAULT_DATE)
+
+        # Assert operator link is empty when no XCom push occurred
+        self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name), "")
+
+        # Assert operator link is empty for deserialized task when no XCom push occurred
+        self.assertEqual(
+            deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+            "",
+        )
+
+        ti.xcom_push(key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED)
+
+        # Assert operator links are preserved in deserialized tasks
+        self.assertEqual(
+            deserialized_task.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+            DATAPROC_CLUSTER_LINK_EXPECTED,
+        )
+        # Assert operator links after execution
+        self.assertEqual(
+            op.get_extra_links(DEFAULT_DATE, DataprocClusterLink.name),
+            DATAPROC_CLUSTER_LINK_EXPECTED,
+        )
+        # Check for negative case
+        self.assertEqual(op.get_extra_links(datetime(2020, 7, 20), DataprocClusterLink.name), "")
+
 
 class TestDataprocWorkflowTemplateInstantiateOperator(unittest.TestCase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -787,7 +1133,7 @@ def test_execute(self, mock_hook, mock_uuid):
             variables=self.variables,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={})
+        op.execute(context=MagicMock())
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.submit_job.assert_called_once_with(
             project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
@@ -846,7 +1192,7 @@ def test_execute(self, mock_hook, mock_uuid):
             variables=self.variables,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={})
+        op.execute(context=MagicMock())
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.submit_job.assert_called_once_with(
             project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
@@ -911,7 +1257,7 @@ def test_execute(self, mock_hook, mock_uuid):
             variables=self.variables,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={})
+        op.execute(context=MagicMock())
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.submit_job.assert_called_once_with(
             project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
@@ -937,7 +1283,7 @@ def test_execute_override_project_id(self, mock_hook, mock_uuid):
             variables=self.variables,
             impersonation_chain=IMPERSONATION_CHAIN,
         )
-        op.execute(context={})
+        op.execute(context=MagicMock())
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
         mock_hook.return_value.submit_job.assert_called_once_with(
             project_id="other-project", job=self.other_project_job, location=GCP_LOCATION
@@ -963,12 +1309,14 @@ def test_builder(self, mock_hook, mock_uuid):
         assert self.job == job
 
 
-class TestDataProcSparkOperator(unittest.TestCase):
+class TestDataProcSparkOperator(DataprocJobTestBase):
     main_class = "org.apache.spark.examples.SparkPi"
     jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"]
-    job_id = "uuid_id"
     job = {
-        "reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
+        "reference": {
+            "project_id": GCP_PROJECT,
+            "job_id": "{{task.task_id}}_{{ds_nodash}}_" + TEST_JOB_ID,
+        },
         "placement": {"cluster_name": "cluster-1"},
         "labels": {"airflow-version": AIRFLOW_VERSION},
         "spark_job": {"jar_file_uris": jars, "main_class": main_class},
@@ -985,9 +1333,11 @@ def test_deprecation_warning(self, mock_hook):
     @mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute(self, mock_hook, mock_uuid):
-        mock_uuid.return_value = self.job_id
+        mock_uuid.return_value = TEST_JOB_ID
         mock_hook.return_value.project_id = GCP_PROJECT
-        mock_uuid.return_value = self.job_id
+        mock_uuid.return_value = TEST_JOB_ID
+        mock_hook.return_value.submit_job.return_value.reference.job_id = TEST_JOB_ID
+        self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')
 
         op = DataprocSubmitSparkJobOperator(
             task_id=TASK_ID,
@@ -999,6 +1349,68 @@ def test_execute(self, mock_hook, mock_uuid):
         job = op.generate_job()
         assert self.job == job
 
+        op.execute(context=self.mock_context)
+        self.mock_ti.xcom_push.assert_called_once_with(
+            key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
+        )
+
+        # Test whether xcom push occurs before polling for job
+        self.extra_links_manager_mock.assert_has_calls(self.extra_links_expected_calls, any_order=False)
+
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_operator_extra_links(self, mock_hook):
+        mock_hook.return_value.project_id = GCP_PROJECT
+
+        op = DataprocSubmitSparkJobOperator(
+            task_id=TASK_ID,
+            region=GCP_LOCATION,
+            gcp_conn_id=GCP_CONN_ID,
+            main_class=self.main_class,
+            dataproc_jars=self.jars,
+            dag=self.dag,
+        )
+
+        serialized_dag = SerializedDAG.to_dict(self.dag)
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+        deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+        # Assert operator links for serialized DAG
+        self.assertEqual(
+            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+            [{"airflow.providers.google.cloud.operators.dataproc.DataprocJobLink": {}}],
+        )
+
+        # Assert operator link types are preserved during deserialization
+        self.assertIsInstance(deserialized_task.operator_extra_links[0], DataprocJobLink)
+
+        ti = TaskInstance(task=op, execution_date=DEFAULT_DATE)
+
+        # Assert operator link is empty when no XCom push occurred
+        self.assertEqual(op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "")
+
+        # Assert operator link is empty for deserialized task when no XCom push occurred
+        self.assertEqual(deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name), "")
+
+        ti.xcom_push(key="job_conf", value=DATAPROC_JOB_CONF_EXPECTED)
+
+        # Assert operator links after task execution
+        self.assertEqual(
+            op.get_extra_links(DEFAULT_DATE, DataprocJobLink.name),
+            DATAPROC_JOB_LINK_EXPECTED,
+        )
+
+        # Assert operator links are preserved in deserialized tasks
+        self.assertEqual(
+            deserialized_task.get_extra_links(DEFAULT_DATE, DataprocJobLink.name),
+            DATAPROC_JOB_LINK_EXPECTED,
+        )
+
+        # Assert for negative case
+        self.assertEqual(
+            deserialized_task.get_extra_links(datetime(2020, 7, 20), DataprocJobLink.name),
+            "",
+        )
+
 
 class TestDataProcHadoopOperator(unittest.TestCase):
     args = ["wordcount", "gs://pub/shakespeare/rose.txt"]
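
---

Reviewer note (not part of the patch): the two link classes are just `str.format()` over a small dict that the operator pushes to XCom before it starts polling, so the console button works even for failed or still-running jobs. A minimal, self-contained sketch of that mechanism is below; the `job_conf` values are made up for illustration, and the real dict is pushed by the operator at runtime.

```python
# Standalone sketch of what DataprocJobLink.get_link() effectively does:
# format a console URL from the XCom-pushed "job_conf" dict, or return ""
# when nothing was pushed. All concrete values here are illustrative only.
DATAPROC_BASE_LINK = "https://console.cloud.google.com/dataproc"
DATAPROC_JOB_LOG_LINK = DATAPROC_BASE_LINK + "/jobs/{job_id}?region={region}&project={project_id}"

job_conf = {"job_id": "test-job", "region": "us-central1", "project_id": "example-project"}

# Keys in job_conf match the placeholders, so ** unpacking fills the template.
link = DATAPROC_JOB_LOG_LINK.format(**job_conf) if job_conf else ""
print(link)
# -> https://console.cloud.google.com/dataproc/jobs/test-job?region=us-central1&project=example-project
```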