-
Notifications
You must be signed in to change notification settings - Fork 14.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add test_connection method to GoogleBaseHook #24682
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ | |
import google.auth.credentials | ||
import google.oauth2.service_account | ||
import google_auth_httplib2 | ||
import requests | ||
import tenacity | ||
from google.api_core.exceptions import Forbidden, ResourceExhausted, TooManyRequests | ||
from google.api_core.gapic_v1.client_info import ClientInfo | ||
|
@@ -270,7 +271,12 @@ def _get_credentials(self) -> google.auth.credentials.Credentials: | |
|
||
def _get_access_token(self) -> str: | ||
"""Returns a valid access token from Google API Credentials""" | ||
return self._get_credentials().token | ||
credentials = self._get_credentials() | ||
auth_req = google.auth.transport.requests.Request() | ||
# credentials.token is None | ||
# Need to refresh credentials to populate the token | ||
credentials.refresh(auth_req) | ||
return credentials.token | ||
|
||
@functools.lru_cache(maxsize=None) | ||
def _get_credentials_email(self) -> str: | ||
|
@@ -580,3 +586,19 @@ def download_content_from_request(file_handle, request: dict, chunk_size: int) - | |
while done is False: | ||
_, done = downloader.next_chunk() | ||
file_handle.flush() | ||
|
||
def test_connection(self): | ||
phanikumv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Test the Google cloud connectivity from UI""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: "Google Cloud" |
||
status, message = False, '' | ||
try: | ||
token = self._get_access_token() | ||
url = f"https://www.googleapis.com/oauth2/v3/tokeninfo?access_token={token}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It doesn't look safe, as we do not have url encoding here. |
||
response = requests.post(url) | ||
if response.status_code == 200: | ||
status = True | ||
message = 'Connection successfully tested' | ||
except Exception as e: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AFAIK catching all type of exceptions is bad practice. |
||
status = False | ||
message = str(e) | ||
|
||
return status, message |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -341,6 +341,24 @@ def test_get_credentials_and_project_id_with_default_auth(self, mock_get_creds_a | |
) | ||
assert ('CREDENTIALS', 'PROJECT_ID') == result | ||
|
||
@mock.patch('requests.post') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good practice is to mock with autospec=True to catch passing incorrect parameters to method. |
||
@mock.patch(MODULE_NAME + '.get_credentials_and_project_id') | ||
def test_connection_success(self, mock_get_creds_and_proj_id, requests_post): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for the consistency sake I would name "requests_post" -> "mock_requests_post" |
||
requests_post.return_value.status_code = 200 | ||
credentials = mock.MagicMock() | ||
type(credentials).token = mock.PropertyMock(return_value="TOKEN") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should work instead: |
||
mock_get_creds_and_proj_id.return_value = (credentials, "PROJECT_ID") | ||
self.instance.extras = {} | ||
result = self.instance.test_connection() | ||
assert result == (True, 'Connection successfully tested') | ||
|
||
@mock.patch(MODULE_NAME + '.get_credentials_and_project_id') | ||
def test_connection_failure(self, mock_get_creds_and_proj_id): | ||
mock_get_creds_and_proj_id.side_effect = AirflowException('Invalid key JSON.') | ||
self.instance.extras = {} | ||
result = self.instance.test_connection() | ||
assert result == (False, 'Invalid key JSON.') | ||
|
||
@mock.patch(MODULE_NAME + '.get_credentials_and_project_id') | ||
def test_get_credentials_and_project_id_with_service_account_file(self, mock_get_creds_and_proj_id): | ||
mock_credentials = mock.MagicMock() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How was it working before the change if it was returning
None
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The method
_get_access_token()
wasn't being used anywhere earlier. A sample dump of the Credentials object just after callingself._get_credentials()
is :{'token': None, 'expiry': None, '_quota_project_id': None, '_scopes': ('https://www.googleapis.com/auth/cloud-platform',), '_default_scopes': None, '_signer': <google.auth.crypt._cryptography_rsa.RSASigner object at 0x40359f3af0>, '_service_account_email': 'phani-svc-account@xxxxxxxx.iam.gserviceaccount.com', '_subject': None, '_project_id': 'xxxxxx-providers', '_token_uri': 'https://oauth2.googleapis.com/token', '_always_use_jwt_access': False, '_jwt_credentials': None, '_additional_claims': {}}