Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Databricks Deferrable Operators #19736

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
or the ``api/2.1/jobs/runs/submit``
`endpoint <https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit>`_.
"""
import json
from typing import Any, Dict, List, Optional

from requests import exceptions as requests_exceptions
Expand Down Expand Up @@ -92,6 +93,13 @@ def __eq__(self, other: object) -> bool:
def __repr__(self) -> str:
    """Return the string form of the instance's attribute dictionary."""
    attributes = self.__dict__
    return str(attributes)

def to_json(self) -> str:
    """Serialize the instance's attribute dictionary to a JSON string."""
    attributes = self.__dict__
    return json.dumps(attributes)

@classmethod
def from_json(cls, data: str) -> 'RunState':
    """
    Build an instance from a JSON string, such as the one produced by ``to_json()``.

    :param data: JSON object whose keys are passed as keyword arguments to the constructor
    :return: a new instance of the class this method was called on
    """
    # Instantiate ``cls`` instead of a hard-coded class so that subclasses
    # deserialize to their own type.
    return cls(**json.loads(data))


class DatabricksHook(BaseDatabricksHook):
"""
Expand Down Expand Up @@ -198,6 +206,16 @@ def get_run_page_url(self, run_id: int) -> str:
response = self._do_api_call(GET_RUN_ENDPOINT, json)
return response['run_page_url']

async def a_get_run_page_url(self, run_id: int) -> str:
    """
    Asynchronous counterpart of `get_run_page_url()`.

    :param run_id: id of the run
    :return: URL of the run page
    """
    payload = {'run_id': run_id}
    run_info = await self._a_do_api_call(GET_RUN_ENDPOINT, payload)
    return run_info['run_page_url']

def get_job_id(self, run_id: int) -> int:
"""
Retrieves job_id from run_id.
Expand Down Expand Up @@ -229,6 +247,17 @@ def get_run_state(self, run_id: int) -> RunState:
state = response['state']
return RunState(**state)

async def a_get_run_state(self, run_id: int) -> RunState:
    """
    Asynchronous counterpart of `get_run_state()`.

    :param run_id: id of the run
    :return: state of the run
    """
    run_info = await self._a_do_api_call(GET_RUN_ENDPOINT, {'run_id': run_id})
    return RunState(**run_info['state'])

def get_run_state_str(self, run_id: int) -> str:
"""
Return the string representation of RunState.
Expand Down
237 changes: 224 additions & 13 deletions airflow/providers/databricks/hooks/databricks_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,19 @@
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse

import aiohttp
import requests
from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase, HTTPBasicAuth
from requests.exceptions import JSONDecodeError
from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential
from tenacity import (
AsyncRetrying,
RetryError,
Retrying,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)

from airflow import __version__
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -135,6 +143,14 @@ def host(self) -> str:

return host

async def __aenter__(self):
    """Open a new aiohttp client session, store it on the hook and return the hook itself."""
    self._session = aiohttp.ClientSession()
    return self

async def __aexit__(self, *err):
    """Close the aiohttp client session held by the hook and clear the reference to it."""
    await self._session.close()
    self._session = None

@staticmethod
def _parse_host(host: str) -> str:
"""
Expand Down Expand Up @@ -169,6 +185,13 @@ def _get_retry_object(self) -> Retrying:
"""
return Retrying(**self.retry_args)

def _a_get_retry_object(self) -> AsyncRetrying:
    """
    Create the retry controller for asynchronous calls.

    :return: a new AsyncRetrying instance configured from ``self.retry_args``
    """
    retry_options = self.retry_args
    return AsyncRetrying(**retry_options)

def _get_aad_token(self, resource: str) -> str:
"""
Function to get AAD token for given resource. Supports managed identity or service principal auth
Expand Down Expand Up @@ -234,6 +257,72 @@ def _get_aad_token(self, resource: str) -> str:

return token

async def _a_get_aad_token(self, resource: str) -> str:
    """
    Async version of `_get_aad_token()`.

    Returns the token cached in ``self.aad_tokens`` for ``resource`` when it is still valid;
    otherwise requests a new one through the hook's aiohttp session and caches it.

    :param resource: resource to issue token to
    :return: AAD token, or raise an exception
    """
    # Reuse the cached token while `_is_aad_token_valid` accepts it.
    aad_token = self.aad_tokens.get(resource)
    if aad_token and self._is_aad_token_valid(aad_token):
        return aad_token['token']

    # Also reached when no token has been cached for this resource yet.
    self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
    try:
        async for attempt in self._a_get_retry_object():
            with attempt:
                if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
                    # Managed identity: GET the token from the Azure metadata service URL.
                    params = {
                        "api-version": "2018-02-01",
                        "resource": resource,
                    }
                    async with self._session.get(
                        url=AZURE_METADATA_SERVICE_TOKEN_URL,
                        params=params,
                        headers={**USER_AGENT_HEADER, "Metadata": "true"},
                        timeout=self.aad_timeout_seconds,
                    ) as resp:
                        resp.raise_for_status()
                        jsn = await resp.json()
                else:
                    # Otherwise: client-credentials POST using the connection's login/password.
                    tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
                    data = {
                        "grant_type": "client_credentials",
                        "client_id": self.databricks_conn.login,
                        "resource": resource,
                        "client_secret": self.databricks_conn.password,
                    }
                    azure_ad_endpoint = self.databricks_conn.extra_dejson.get(
                        "azure_ad_endpoint", AZURE_DEFAULT_AD_ENDPOINT
                    )
                    async with self._session.post(
                        url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
                        data=data,
                        headers={
                            **USER_AGENT_HEADER,
                            'Content-Type': 'application/x-www-form-urlencoded',
                        },
                        timeout=self.aad_timeout_seconds,
                    ) as resp:
                        resp.raise_for_status()
                        jsn = await resp.json()
                # The reply must carry a Bearer access token together with its expiry.
                if (
                    'access_token' not in jsn
                    or jsn.get('token_type') != 'Bearer'
                    or 'expires_on' not in jsn
                ):
                    raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")

                # Cache the token with its expiry so later calls can reuse it.
                token = jsn['access_token']
                self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
                break
    except RetryError:
        raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
    except aiohttp.ClientResponseError as err:
        raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')

    return token

def _get_aad_headers(self) -> dict:
"""
Fills AAD headers if necessary (SPN is outside of the workspace)
Expand All @@ -248,6 +337,20 @@ def _get_aad_headers(self) -> dict:
headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
return headers

async def _a_get_aad_headers(self) -> dict:
    """
    Asynchronous counterpart of `_get_aad_headers()`.

    :return: dictionary with the AAD headers; empty when the connection extras
        contain no ``azure_resource_id``
    """
    if 'azure_resource_id' not in self.databricks_conn.extra_dejson:
        return {}

    management_token = await self._a_get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
    return {
        'X-Databricks-Azure-Workspace-Resource-Id': self.databricks_conn.extra_dejson['azure_resource_id'],
        'X-Databricks-Azure-SP-Management-Token': management_token,
    }

@staticmethod
def _is_aad_token_valid(aad_token: dict) -> bool:
"""
Expand Down Expand Up @@ -281,6 +384,23 @@ def _check_azure_metadata_service() -> None:
except (requests_exceptions.RequestException, ValueError) as e:
raise AirflowException(f"Can't reach Azure Metadata Service: {e}")

async def _a_check_azure_metadata_service(self):
    """
    Async version of `_check_azure_metadata_service()`.

    Queries the Azure metadata service instance URL through the hook's aiohttp session
    and checks that the reply contains ``compute.azEnvironment``.

    :raises AirflowException: if the request fails or times out, the reply cannot be
        decoded as JSON, or the reply does not look like Azure metadata
    """
    import asyncio  # local import: only needed for the timeout exception type below

    try:
        async with self._session.get(
            url=AZURE_METADATA_SERVICE_INSTANCE_URL,
            params={"api-version": "2021-02-01"},
            headers={"Metadata": "true"},
            timeout=2,
        ) as resp:
            jsn = await resp.json()
        if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
            raise AirflowException(
                f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
            )
    # The request is made with aiohttp, so its failures surface as aiohttp.ClientError
    # (connection and response errors) or asyncio.TimeoutError, not as ``requests``
    # exceptions. ValueError still covers an undecodable JSON body.
    except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as e:
        raise AirflowException(f"Can't reach Azure Metadata Service: {e}") from e

def _get_token(self, raise_error: bool = False) -> Optional[str]:
if 'token' in self.databricks_conn.extra_dejson:
self.log.info(
Expand All @@ -304,6 +424,29 @@ def _get_token(self, raise_error: bool = False) -> Optional[str]:

return None

async def _a_get_token(self, raise_error: bool = False) -> Optional[str]:
    """
    Asynchronous counterpart of `_get_token()`.

    Checks the connection for, in order: a token in the extras, a password without a login,
    Azure SPN credentials (``azure_tenant_id``), and Azure managed identity.

    :param raise_error: raise instead of returning None when no token auth is configured
    :return: the token to use, or None when token authentication isn't configured
    """
    if 'token' in self.databricks_conn.extra_dejson:
        self.log.info(
            'Using token auth. For security reasons, please set token in Password field instead of extra'
        )
        return self.databricks_conn.extra_dejson["token"]

    if not self.databricks_conn.login and self.databricks_conn.password:
        self.log.info('Using token auth.')
        return self.databricks_conn.password

    if 'azure_tenant_id' in self.databricks_conn.extra_dejson:
        if self.databricks_conn.login == "" or self.databricks_conn.password == "":
            raise AirflowException("Azure SPN credentials aren't provided")
        self.log.info('Using AAD Token for SPN.')
        return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)

    if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
        self.log.info('Using AAD Token for managed identity.')
        await self._a_check_azure_metadata_service()
        return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)

    if raise_error:
        raise AirflowException('Token authentication isn\'t configured')
    return None

def _log_request_error(self, attempt_num: int, error: str) -> None:
    """
    Log a failed Databricks API request attempt at error level.

    :param attempt_num: number of the attempt that failed
    :param error: reason for the failure
    """
    self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)

Expand Down Expand Up @@ -374,6 +517,55 @@ def _do_api_call(
else:
raise e

async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str, Any]] = None):
    """
    Async version of `_do_api_call()`.

    :param endpoint_info: Tuple of method and endpoint
    :param json: Parameters for this API call.
    :return: If the api call returns a OK status code,
        this function returns the response in JSON. Otherwise, throw an AirflowException.
    :raises AirflowException: on an unsupported HTTP method, when the retries are exhausted,
        or when the response carries an error status
    """
    method, endpoint = endpoint_info

    # Reject unsupported methods up front, before any token request is made.
    if method not in ('GET', 'POST', 'PATCH'):
        raise AirflowException('Unexpected HTTP Method: ' + method)
    request_func: Any = getattr(self._session, method.lower())

    url = f'https://{self.host}/{endpoint}'

    aad_headers = await self._a_get_aad_headers()
    # USER_AGENT_HEADER is merged last so it takes precedence over the AAD headers.
    headers = {**aad_headers, **USER_AGENT_HEADER}

    auth: aiohttp.BasicAuth
    token = await self._a_get_token()
    if token:
        auth = BearerAuth(token)
    else:
        self.log.info('Using basic auth.')
        auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)

    try:
        async for attempt in self._a_get_retry_object():
            with attempt:
                async with request_func(
                    url,
                    json=json,
                    auth=auth,
                    headers=headers,
                    timeout=self.timeout_seconds,
                ) as response:
                    response.raise_for_status()
                    return await response.json()
    except RetryError as err:
        # Chain the cause so the underlying failure stays visible in the traceback.
        raise AirflowException(
            f'API requests to Databricks failed {self.retry_limit} times. Giving up.'
        ) from err
    except aiohttp.ClientResponseError as err:
        raise AirflowException(f'Response: {err.message}, Status Code: {err.status}') from err

@staticmethod
def _get_error_code(exception: BaseException) -> str:
if isinstance(exception, requests_exceptions.HTTPError):
Expand All @@ -387,19 +579,25 @@ def _get_error_code(exception: BaseException) -> str:

@staticmethod
def _retryable_error(exception: BaseException) -> bool:
    """
    Tell whether a failed request should be retried.

    For ``requests`` exceptions: connection errors and timeouts are retryable, as are
    responses with status >= 500, status 429, or status 400 with the
    ``COULD_NOT_ACQUIRE_LOCK`` error code. For ``aiohttp.ClientResponseError``:
    status >= 500 or status 429 is retryable.

    :param exception: the exception raised by the request
    :return: True if the request should be retried, False otherwise
    """
    if isinstance(exception, requests_exceptions.RequestException):
        if isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
            exception.response is not None
            and (
                exception.response.status_code >= 500
                or exception.response.status_code == 429
                or (
                    exception.response.status_code == 400
                    and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
                )
            )
        ):
            return True

    if isinstance(exception, aiohttp.ClientResponseError):
        if exception.status >= 500 or exception.status == 429:
            return True

    return False


class _TokenAuth(AuthBase):
Expand All @@ -414,3 +612,16 @@ def __init__(self, token: str) -> None:
def __call__(self, r: PreparedRequest) -> PreparedRequest:
    """Set the ``Authorization`` header of the request to the bearer token and return the request."""
    bearer_value = 'Bearer ' + self.token
    r.headers['Authorization'] = bearer_value
    return r


class BearerAuth(aiohttp.BasicAuth):
    """Bearer-token auth for aiohttp, built as a BasicAuth subclass because aiohttp ships only BasicAuth."""

    def __new__(cls, token: str) -> 'BearerAuth':
        # The token is handed to BasicAuth's constructor as its first positional argument.
        instance = super().__new__(cls, token)  # type: ignore
        return instance

    def __init__(self, token: str) -> None:
        self.token = token

    def encode(self) -> str:
        """Return the ``Bearer <token>`` string in place of a Basic credential."""
        return 'Bearer {}'.format(self.token)
Loading